
Start tracked_module refactor
pomonam committed Jun 19, 2024
1 parent b2d1664 commit 795d454
Showing 1 changed file with 37 additions and 32 deletions.
69 changes: 37 additions & 32 deletions kronfluence/module/tracked_module.py
@@ -42,6 +42,11 @@ class ModuleMode(str, BaseEnum):
     SELF_MEASUREMENT_SCORE = "self_measurement_score"
 
 
+def null(_: Any) -> None:
+    """Dummy function that does not perform any operations."""
+    pass
+
+
 class TrackedModule(nn.Module):
     """A wrapper class for PyTorch modules to compute preconditioning factors and influence scores."""
 
@@ -86,6 +91,9 @@ def __init__(
                 dtype=torch.float16,
             )
         )
+        # Operations that will be performed before and after a forward pass.
+        self._pre_forward = null
+        self._post_forward = null
 
         if factor_args is None:
             factor_args = FactorArguments()
@@ -144,11 +152,16 @@ def set_factor(self, factor_name: str, factor: Any) -> None:
     def forward(self, inputs: torch.Tensor, *args, **kwargs) -> Any:
         """A forward pass of the tracked module. This should have identical behavior to
         the original module."""
-        return self.original_module(inputs + self._constant, *args, **kwargs)
+        self._pre_forward(inputs)
+        outputs = self.original_module(inputs + self._constant, *args, **kwargs)
+        self._post_forward(outputs)
+        return outputs
 
     def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None:
         """Sets the module mode of all `TrackedModule` instances within a model."""
         current_mode = self._mode
+        self._pre_forward = null
+        self._post_forward = null
         self.remove_registered_hooks()
 
         if current_mode == ModuleMode.COVARIANCE and not keep_factors:
@@ -172,6 +185,8 @@ def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None:
             self.release_scores()
 
         if mode == ModuleMode.DEFAULT and not keep_factors:
+            self._pre_forward = null
+            self._post_forward = null
             # Releases all factors when the mode is set to default.
             self.remove_attention_mask()
             self._release_covariance_matrices()
@@ -181,10 +196,12 @@ def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None:
             self.release_scores()
 
         if mode == ModuleMode.COVARIANCE:
-            self._register_covariance_hooks()
+            self._pre_forward = self._covariance_pre_forward
+            self._post_forward = self._covariance_post_forward
 
         if mode == ModuleMode.LAMBDA:
-            self._register_lambda_hooks()
+            self._pre_forward = self._activation_cache_pre_forward
+            self._post_forward = self._lambda_post_forward
 
         if mode == ModuleMode.PRECONDITION_GRADIENT:
             self._register_precondition_gradient_hooks()
@@ -326,25 +343,19 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None:
         # Add the current batch's pseudo-gradient covariance to the stored pseudo-gradient covariance matrix.
         self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient)
 
-    def _register_covariance_hooks(self) -> None:
-        """Installs forward and backward hooks for computation of the covariance matrices."""
-
-        def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None:
-            del module
-            with torch.no_grad():
-                # Compute and update activation covariance matrix in the forward pass.
-                self._update_activation_covariance_matrix(inputs[0].detach())
-                # Register backward hook to obtain gradient with respect to the output.
-                self._cached_hooks.append(outputs.register_hook(backward_hook))
+    @torch.no_grad()
+    def _covariance_pre_forward(self, inputs: Any) -> Any:
+        # Compute and update activation covariance matrix in the forward pass.
+        self._update_activation_covariance_matrix(inputs.detach())
 
+    def _covariance_post_forward(self, outputs: Any) -> Any:
         @torch.no_grad()
         def backward_hook(output_gradient: torch.Tensor) -> None:
-            handle = self._cached_hooks.pop()
-            handle.remove()
             # Compute and update pseudo-gradient covariance matrix in the backward pass.
             self._update_gradient_covariance_matrix(output_gradient.detach())
 
-        self._registered_hooks.append(self.register_forward_hook(forward_hook))
+        # Register backward hook to obtain gradient with respect to the output.
+        outputs.register_hook(backward_hook)
 
     def _release_covariance_matrices(self) -> None:
         """Clears the stored activation and pseudo-gradient covariance matrices from memory."""
@@ -483,24 +494,17 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
 
         self._storage[NUM_LAMBDA_PROCESSED].add_(batch_size)
 
-    def _register_lambda_hooks(self) -> None:
-        """Installs forward and backward hooks for computation of the Lambda matrices."""
-
-        def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None:
-            del module
-            with torch.no_grad():
-                cached_activation = inputs[0].detach().to(dtype=self.factor_args.lambda_dtype)
-                if self.factor_args.cached_activation_cpu_offload:
-                    self._cached_activations.append(cached_activation.cpu())
-                else:
-                    self._cached_activations.append(cached_activation)
-                # Register backward hook to obtain gradient with respect to the output.
-                self._cached_hooks.append(outputs.register_hook(backward_hook))
+    @torch.no_grad()
+    def _activation_cache_pre_forward(self, inputs: Any) -> Any:
+        cached_activation = inputs.detach().to(dtype=self.factor_args.lambda_dtype)
+        if self.factor_args.cached_activation_cpu_offload:
+            self._cached_activations.append(cached_activation.cpu())
+        else:
+            self._cached_activations.append(cached_activation)
 
+    def _lambda_post_forward(self, outputs: torch.Tensor) -> Any:
         @torch.no_grad()
         def backward_hook(output_gradient: torch.Tensor) -> None:
-            handle = self._cached_hooks.pop()
-            handle.remove()
             cached_activation = self._cached_activations.pop()
             if self.factor_args.cached_activation_cpu_offload:
                 cached_activation = cached_activation.to(device=output_gradient.device)
@@ -522,7 +526,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
             del self._cached_per_sample_gradient
             self._cached_per_sample_gradient = None
 
-        self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
+        # Register backward hook to obtain gradient with respect to the output.
+        outputs.register_hook(backward_hook)
 
     def _release_lambda_matrix(self) -> None:
         """Clears the stored Lambda matrix from memory."""

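For context on the refactor, the sketch below shows the pattern this diff moves toward: instead of registering nn.Module forward hooks, the wrapper calls swappable _pre_forward/_post_forward callables around the wrapped module, and the post-forward step attaches a tensor-level register_hook to capture the gradient with respect to the output. This is a minimal, self-contained illustration, not kronfluence's actual API; DemoTracker, enable_tracking, and the logging helpers are made-up names.

import torch
import torch.nn as nn


def null(_: torch.Tensor) -> None:
    """Dummy function that does not perform any operations."""


class DemoTracker(nn.Module):
    """Illustrative wrapper (not kronfluence's TrackedModule) built around swappable callables."""

    def __init__(self, original_module: nn.Module) -> None:
        super().__init__()
        self.original_module = original_module
        # Operations performed before and after a forward pass; no-ops by default.
        self._pre_forward = null
        self._post_forward = null

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        self._pre_forward(inputs)
        outputs = self.original_module(inputs)
        self._post_forward(outputs)
        return outputs

    def enable_tracking(self) -> None:
        # Swap in the callables, mirroring how set_mode assigns _covariance_pre_forward, etc.
        self._pre_forward = self._log_activation
        self._post_forward = self._log_output_gradient

    @torch.no_grad()
    def _log_activation(self, inputs: torch.Tensor) -> None:
        # Stand-in for updating an activation covariance matrix.
        print("activation norm:", inputs.detach().norm().item())

    def _log_output_gradient(self, outputs: torch.Tensor) -> None:
        @torch.no_grad()
        def backward_hook(output_gradient: torch.Tensor) -> None:
            # Stand-in for updating a pseudo-gradient covariance matrix.
            print("output-gradient norm:", output_gradient.detach().norm().item())

        # Tensor-level hook fires when the gradient with respect to the output is computed.
        outputs.register_hook(backward_hook)


tracked = DemoTracker(nn.Linear(4, 2))
tracked.enable_tracking()
tracked(torch.randn(3, 4)).sum().backward()  # triggers the backward hook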