Skip to content

Commit

Permalink
Add finalization step for cov computations
Browse files: browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 20, 2024
1 parent 9f1bbf5 commit 68f5c71
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 170 deletions.
7 changes: 4 additions & 3 deletions kronfluence/factor/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
set_gradient_scale,
set_mode,
synchronize_covariance_matrices,
update_factor_args,
update_factor_args, finalize_covariance_matrices,
)
from kronfluence.task import Task
from kronfluence.utils.constants import (
Expand Down Expand Up @@ -144,7 +144,9 @@ def fit_covariance_matrices_with_loader(
if enable_amp:
gradient_scale = 1.0 / scaler.get_scale()
set_gradient_scale(model=model, gradient_scale=gradient_scale)
original_model = None
if factor_args.compile_mode is not None:
original_model = model
model = torch.compile(model, mode=factor_args.compile_mode)

with tqdm(
Expand Down Expand Up @@ -185,6 +187,7 @@ def fit_covariance_matrices_with_loader(
pbar.update(1)

with torch.no_grad():
finalize_covariance_matrices(model=model)
if state.use_distributed:
# Aggregate covariance matrices across multiple devices or nodes.
synchronize_covariance_matrices(model=model)
Expand All @@ -198,8 +201,6 @@ def fit_covariance_matrices_with_loader(

# Clean up the memory.
model.zero_grad(set_to_none=True)
remove_attention_mask(model=model)
remove_gradient_scale(model=model)
set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False)

return num_data_processed, saved_factors
Loading

0 comments on commit 68f5c71

Please sign in to comment.