Skip to content

Commit

Permalink
Merge pull request #24 from pomonam/cleanup_factor
Browse files (browse the repository at this point in the history)
Support gradient checkpointing
  • Loading branch information
pomonam committed Jun 23, 2024
2 parents (883cf28 + 293e3d2); commit 2b7dbd3
Show file tree
Hide file tree
Showing 30 changed files with 1,233 additions and 909 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ cython_debug/
# Checkpoints and influence outputs
checkpoints/
analyses/
influence_results/
data/
*.pth
*.pt
27 changes: 15 additions & 12 deletions kronfluence/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def prepare_model(
task: Task,
) -> nn.Module:
"""Prepares the model before passing it to `Analyzer`. This function sets `param.requires_grad = False`
for all modules that does not require influence computations and installs `TrackedModule` to supported
modules. This `TrackedModule` keeps track of relevant statistics needed to compute influence scores.
for all modules and installs `TrackedModule` to supported modules. This `TrackedModule` keeps track of relevant
statistics needed to compute influence scores.
Args:
model (nn.Module):
Expand All @@ -32,22 +32,21 @@ def prepare_model(
Returns:
nn.Module:
The same PyTorch model with `param.requires_grad = False` on all modules that does not require influence
computations and with `TrackedModule` installed.
The PyTorch model with `param.requires_grad = False` on all modules and with `TrackedModule` installed.
"""
model.eval()
for params in model.parameters():
params.requires_grad = False
for buffers in model.buffers():
buffers.requires_grad = False
# Install `TrackedModule` to the model.
model = wrap_tracked_modules(model=model, task=task)
return model


class Analyzer(FactorComputer, ScoreComputer):
"""
Handles the computation of all factors (e.g., covariance and Lambda matrices for EKFAC)
and influence scores for a given PyTorch model.
"""
"""Handles the computation of all factors (e.g., covariance and Lambda matrices for EKFAC)
and influence scores for a given PyTorch model."""

def __init__(
self,
Expand All @@ -58,7 +57,8 @@ def __init__(
log_level: Optional[int] = None,
log_main_process_only: bool = True,
profile: bool = False,
output_dir: str = "./analyses",
disable_tqdm: bool = False,
output_dir: str = "./influence_results",
disable_model_save: bool = True,
) -> None:
"""Initializes an instance of the Analyzer class.
Expand All @@ -80,9 +80,11 @@ def __init__(
profile (bool, optional):
Enables the generation of performance profiling logs. This can be useful for
identifying bottlenecks or performance issues. Defaults to False.
disable_tqdm (bool, optional):
Disables TQDM progress bars. Defaults to False.
output_dir (str):
The file path to the directory, where analysis results will be stored. If the directory
does not exist, it will be created. Defaults to './analyses'.
does not exist, it will be created. Defaults to './influence_results'.
disable_model_save (bool, optional):
If set to True, prevents the saving of the model's state_dict. When the provided model is different
from the previously saved model, it will raise an Exception. Defaults to True.
Expand All @@ -95,12 +97,13 @@ def __init__(
log_level=log_level,
log_main_process_only=log_main_process_only,
profile=profile,
disable_tqdm=disable_tqdm,
output_dir=output_dir,
)
self.logger.info(f"Initializing Computer with parameters: {locals()}")
self.logger.debug(f"Process state configuration:\n{repr(self.state)}")

# Save model parameters.
# Saves model parameters.
if self.state.is_main_process and not disable_model_save:
self._save_model()
self.state.wait_for_everyone()
Expand Down Expand Up @@ -151,7 +154,7 @@ def fit_all_factors(
) -> None:
"""Computes all necessary factors for the given factor strategy. As an example, EK-FAC
requires (1) computing covariance matrices, (2) performing Eigendecomposition, and
(3) computing Lambda (corrected-eigenvalues) matrices.
(3) computing Lambda (corrected eigenvalues) matrices.
Args:
factors_name (str):
Expand Down
12 changes: 10 additions & 2 deletions kronfluence/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class FactorArguments(Arguments):
default=None,
metadata={"help": "Dtype for automatic mixed precision (AMP). Disables AMP if None."},
)
shared_parameters_exist: bool = field(
default=False,
metadata={"help": "Specifies whether the shared parameters exist in the forward pass."},
)

# Configuration for fitting covariance matrices. #
covariance_max_examples: Optional[int] = field(
Expand Down Expand Up @@ -125,6 +129,10 @@ class FactorArguments(Arguments):
default=False,
metadata={"help": "Whether to offload cached activation to CPU for computing the per-sample-gradient."},
)
per_sample_gradient_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing per-sample-gradients."},
)
lambda_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing Lambda (corrected eigenvalues) matrices."},
Expand Down Expand Up @@ -188,9 +196,9 @@ class ScoreArguments(Arguments):
default=None,
metadata={"help": "Rank for the query gradient. Does not apply low-rank approximation if None."},
)
num_query_gradient_aggregations: int = field(
num_query_gradient_accumulations: int = field(
default=1,
metadata={"help": "Number of query batches to aggregate over."},
metadata={"help": "Number of query batches to accumulate over."},
)
use_measurement_for_self_influence: bool = field(
default=False,
Expand Down
38 changes: 22 additions & 16 deletions kronfluence/computer/computer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import time
from abc import ABC
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -38,6 +39,7 @@
TrackedModuleNotFoundError,
)
from kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger
from kronfluence.utils.model import apply_ddp
from kronfluence.utils.save import (
FACTOR_ARGUMENTS_NAME,
FACTOR_SAVE_PREFIX,
Expand All @@ -62,11 +64,12 @@ def __init__(
log_level: Optional[int] = logging.INFO,
log_main_process_only: bool = True,
profile: bool = False,
disable_tqdm: bool = False,
) -> None:
"""Initializes an instance of the Computer class."""
self.state = State(cpu=cpu)

# Create and configure logger.
# Creates and configures logger.
disable_log = log_main_process_only and self.state.process_index != 0
self.logger = get_logger(name=__name__, log_level=log_level, disable_log=disable_log)

Expand All @@ -91,10 +94,11 @@ def __init__(
)
self.logger.warning(warning_msg)
self.model.to(self.state.device)
self.model = DDP(
self.model,
device_ids=[self.state.local_process_index],
output_device=self.state.local_process_index,
self.model = apply_ddp(
model=self.model,
local_rank=self.state.local_process_index,
rank=self.state.process_index,
world_size=self.state.num_processes,
)

if cpu and isinstance(model, (DataParallel, DDP, FSDP)):
Expand All @@ -105,17 +109,18 @@ def __init__(
if not self.state.use_distributed:
self.model.to(self.state.device)

# Create and configure output directory.
# Creates and configures output directory.
self.output_dir = Path(output_dir).joinpath(name).resolve()
os.makedirs(name=self.output_dir, exist_ok=True)

# Create and configure profiler.
# Creates and configures profiler.
self.profiler = Profiler(state=self.state) if profile else PassThroughProfiler(state=self.state)
# Create directory to save profiler output.
# Creates directory to save profiler output.
self.profiler_dir = (self.output_dir / "profiler_output").resolve()
os.makedirs(name=self.profiler_dir, exist_ok=True)
self.disable_tqdm = disable_tqdm

# DataLoader parameters.
# Sets PyTorch DataLoader arguments.
self._dataloader_params = DataLoaderKwargs()

def factors_output_dir(self, factors_name: str) -> Path:
Expand Down Expand Up @@ -169,7 +174,7 @@ def _save_dataset_metadata(

if dataset_metadata_save_path.exists() and not overwrite_output_dir:
self.logger.info(f"Found existing saved dataset metadata at `{dataset_metadata_save_path}`.")
# Load the existing dataset metadata for comparison.
# Loads the existing dataset metadata for comparison.
loaded_metadata = load_json(dataset_metadata_save_path)
if loaded_metadata != dataset_metadata:
error_msg = (
Expand Down Expand Up @@ -318,9 +323,10 @@ def _get_module_partition(
return modules_partition_list, target_module_partitions

def _log_profile_summary(self) -> None:
"""Log the summary of the profiling results."""
"""Saves the summary of the profiling results."""
profile_summary = self.profiler.summary()
profile_save_path = (self.profiler_dir / f"summary_rank_{self.state.process_index}.txt").resolve()
time_str = time.strftime("%Y%m%d_%H%M%S")
profile_save_path = (self.profiler_dir / f"summary_rank_{self.state.process_index}_{time_str}.txt").resolve()
if profile_summary != "":
self.logger.info(profile_summary)
with open(profile_save_path, "a", encoding="utf-8") as f:
Expand Down Expand Up @@ -378,7 +384,7 @@ def load_self_scores(self, scores_name: str) -> Optional[SCORE_TYPE]:
return None

def load_all_factors(self, factors_name: str) -> FACTOR_TYPE:
"""Loads factors from disk."""
"""Loads all relevant factors from disk."""
from kronfluence.factor.config import ( # pylint: disable=import-outside-toplevel
FactorConfig,
)
Expand All @@ -396,7 +402,7 @@ def load_all_factors(self, factors_name: str) -> FACTOR_TYPE:
covariance_factors = self.load_covariance_matrices(factors_name=factors_name)
if covariance_factors is None:
error_msg = (
f"Strategy `{factor_args.strategy}` computing covariance matrices. "
f"Strategy `{factor_args.strategy}` requires covariance matrices. "
f"However, the covariance matrices were not found."
)
self.logger.error(error_msg)
Expand All @@ -407,7 +413,7 @@ def load_all_factors(self, factors_name: str) -> FACTOR_TYPE:
eigen_factors = self.load_eigendecomposition(factors_name=factors_name)
if eigen_factors is None:
error_msg = (
f"Strategy `{factor_args.strategy}` computing Eigendecomposition. "
f"Strategy `{factor_args.strategy}` requires Eigendecomposition results. "
f"However, the Eigendecomposition results were not found."
)
self.logger.error(error_msg)
Expand All @@ -418,7 +424,7 @@ def load_all_factors(self, factors_name: str) -> FACTOR_TYPE:
lambda_factors = self.load_lambda_matrices(factors_name=factors_name)
if lambda_factors is None:
error_msg = (
f"Strategy `{factor_args.strategy}` computing Lambda matrices. "
f"Strategy `{factor_args.strategy}` requires Lambda matrices. "
f"However, the Lambda matrices were not found."
)
self.logger.error(error_msg)
Expand Down
13 changes: 9 additions & 4 deletions kronfluence/computer/factor_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class FactorComputer(Computer):
def _configure_and_save_factor_args(
self, factor_args: Optional[FactorArguments], factors_output_dir: Path, overwrite_output_dir: bool
) -> FactorArguments:
"""Configure the provided factor arguments and save it in disk."""
"""Configures the provided factor arguments and saves it in disk."""
if factor_args is None:
factor_args = FactorArguments()
self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.")
Expand Down Expand Up @@ -139,9 +139,9 @@ def _find_executable_factors_batch_size(

def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Release all memory that could be caused by the previous OOM.
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
# Releases all memory that could be caused by the previous OOM.
self.model.zero_grad(set_to_none=True)
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
total_batch_size = batch_size * self.state.num_processes
loader = self._get_dataloader(
Expand Down Expand Up @@ -293,6 +293,7 @@ def fit_covariance_matrices(
"task": self.task,
"factor_args": factor_args,
"tracked_module_names": module_partition_names[module_partition],
"disable_tqdm": True,
}
per_device_batch_size = self._find_executable_factors_batch_size(
func=fit_covariance_matrices_with_loader,
Expand Down Expand Up @@ -320,6 +321,7 @@ def fit_covariance_matrices(
loader=loader,
factor_args=factor_args,
tracked_module_names=module_partition_names[module_partition],
disable_tqdm=self.disable_tqdm,
)
end_time = get_time(state=self.state)
elapsed_time = end_time - start_time
Expand Down Expand Up @@ -437,7 +439,7 @@ def perform_eigendecomposition(
covariance_factors = load_covariance_matrices(output_dir=load_factors_output_dir)

if load_from_factors_name is not None and self.state.is_main_process:
# Save the loaded covariances to the current factor output directory.
# Saves the loaded covariances to the current factor output directory.
with self.profiler.profile("Save Covariance"):
save_covariance_matrices(output_dir=factors_output_dir, factors=covariance_factors)
loaded_factor_args = self.load_factor_args(factors_name=load_from_factors_name)
Expand All @@ -459,6 +461,7 @@ def perform_eigendecomposition(
model=self.model,
state=self.state,
factor_args=factor_args,
disable_tqdm=self.disable_tqdm,
)
end_time = time.time()
elapsed_time = end_time - start_time
Expand Down Expand Up @@ -642,6 +645,7 @@ def fit_lambda_matrices(
"task": self.task,
"factor_args": factor_args,
"tracked_module_names": module_partition_names[module_partition],
"disable_tqdm": True,
}
per_device_batch_size = self._find_executable_factors_batch_size(
func=fit_lambda_matrices_with_loader,
Expand Down Expand Up @@ -670,6 +674,7 @@ def fit_lambda_matrices(
loader=loader,
factor_args=factor_args,
tracked_module_names=module_partition_names[module_partition],
disable_tqdm=self.disable_tqdm,
)
end_time = get_time(state=self.state)
elapsed_time = end_time - start_time
Expand Down
14 changes: 10 additions & 4 deletions kronfluence/computer/score_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _configure_and_save_score_args(
factors_name: str,
overwrite_output_dir: bool,
) -> Tuple[FactorArguments, ScoreArguments]:
"""Configure the provided factor arguments and save it in disk."""
"""Configures the provided score arguments and saves it in disk."""
if score_args is None:
score_args = ScoreArguments()
self.logger.info(f"Score arguments not provided. Using the default configuration: {score_args}.")
Expand Down Expand Up @@ -173,7 +173,8 @@ def _find_executable_pairwise_scores_batch_size(

def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Release all memory that could be caused by the previous OOM.
# Releases all memory that could be caused by the previous OOM.
self.model.zero_grad(set_to_none=True)
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
total_batch_size = batch_size * self.state.num_processes
Expand Down Expand Up @@ -203,6 +204,7 @@ def executable_batch_size_func(batch_size: int) -> None:
train_loader=train_loader,
per_device_query_batch_size=per_device_query_batch_size,
tracked_module_names=tracked_modules_name,
disable_tqdm=True,
)

per_device_batch_size = find_executable_batch_size(
Expand Down Expand Up @@ -403,6 +405,7 @@ def compute_pairwise_scores(
score_args=score_args,
factor_args=factor_args,
tracked_module_names=module_partition_names[module_partition],
disable_tqdm=self.disable_tqdm,
)
end_time = get_time(state=self.state)
elapsed_time = end_time - start_time
Expand Down Expand Up @@ -487,7 +490,8 @@ def _find_executable_self_scores_batch_size(

def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Release all memory that could be caused by the previous OOM.
# Releases all memory that could be caused by the previous OOM.
self.model.zero_grad(set_to_none=True)
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
total_batch_size = batch_size * self.state.num_processes
Expand All @@ -512,6 +516,7 @@ def executable_batch_size_func(batch_size: int) -> None:
score_args=score_args,
factor_args=factor_args,
tracked_module_names=tracked_modules_name,
disable_tqdm=True,
)

per_device_batch_size = find_executable_batch_size(
Expand All @@ -536,7 +541,7 @@ def compute_self_scores(
overwrite_output_dir: bool = False,
) -> Optional[SCORE_TYPE]:
"""Computes self-influence scores for the given score configuration. As an example,
for T training dataset, the self-influence scores are represented as T-dimensional vector.
for training dataset with T examples, the self-influence scores are represented as T-dimensional vector.
Args:
scores_name (str):
Expand Down Expand Up @@ -691,6 +696,7 @@ def compute_self_scores(
score_args=score_args,
factor_args=factor_args,
tracked_module_names=module_partition_names[module_partition],
disable_tqdm=self.disable_tqdm,
)
end_time = get_time(state=self.state)
elapsed_time = end_time - start_time
Expand Down
Loading

0 comments on commit 2b7dbd3

Please sign in to comment.