Change output directory name
pomonam committed Jun 22, 2024
1 parent 80513bc commit b79206b
Showing 8 changed files with 47 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -165,6 +165,7 @@ cython_debug/
# Checkpoints and influence outputs
checkpoints/
analyses/
influence_results/
data/
*.pth
*.pt
2 changes: 1 addition & 1 deletion kronfluence/computer/score_computer.py
@@ -516,7 +516,7 @@ def executable_batch_size_func(batch_size: int) -> None:
score_args=score_args,
factor_args=factor_args,
tracked_module_names=tracked_modules_name,
disable_tqdm=True
disable_tqdm=True,
)

per_device_batch_size = find_executable_batch_size(
3 changes: 2 additions & 1 deletion kronfluence/factor/covariance.py
@@ -13,13 +13,14 @@
from kronfluence.arguments import FactorArguments
from kronfluence.module.tracked_module import ModuleMode
from kronfluence.module.utils import (
get_tracked_module_names,
load_factors,
remove_gradient_scale,
set_attention_mask,
set_gradient_scale,
set_mode,
synchronize_covariance_matrices,
update_factor_args, get_tracked_module_names,
update_factor_args,
)
from kronfluence.task import Task
from kronfluence.utils.constants import (
3 changes: 2 additions & 1 deletion kronfluence/factor/eigen.py
@@ -21,7 +21,8 @@
set_gradient_scale,
set_mode,
synchronize_lambda_matrices,
update_factor_args, update_aggregated_lambda_matrices,
update_aggregated_lambda_matrices,
update_factor_args,
)
from kronfluence.task import Task
from kronfluence.utils.constants import (
9 changes: 2 additions & 7 deletions kronfluence/module/tracked_module.py
@@ -877,9 +877,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:

# The preconditioning factors need to be loaded to the appropriate device as they will be
# used at each iteration.
if not self._storge_at_current_device:
self._move_storage_to_device(target_device=per_sample_gradient.device)
self._storge_at_current_device = True
self._move_storage_to_device(target_device=per_sample_gradient.device)

if self._cached_per_sample_gradient is None:
self._cached_per_sample_gradient = per_sample_gradient
@@ -941,9 +939,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:

# The preconditioning factors need to be loaded to the appropriate device as they will be
# used at each iteration.
if not self._storge_at_current_device:
self._move_storage_to_device(target_device=per_sample_gradient.device)
self._storge_at_current_device = True
self._move_storage_to_device(target_device=per_sample_gradient.device)

if self._cached_per_sample_gradient is None:
self._cached_per_sample_gradient = per_sample_gradient
@@ -996,4 +992,3 @@ def release_scores(self) -> None:
self._cached_activations = []
del self._cached_per_sample_gradient
self._cached_per_sample_gradient = None
self._storge_at_current_device = False
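The removed `_storge_at_current_device` flag only guarded against repeated device transfers. A minimal sketch of why it is safe to drop, assuming the per-module storage is a dict of tensors and that the move helper skips tensors already on the target device (an illustration, not the library's actual `_move_storage_to_device`):

    import torch

    def move_storage_to_device(storage: dict, target_device: torch.device) -> None:
        # Hypothetical helper: the move is a no-op for tensors already on the
        # target device, so calling it on every backward pass needs no flag.
        for name, tensor in storage.items():
            if tensor is not None and tensor.device != target_device:
                storage[name] = tensor.to(device=target_device)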
4 changes: 2 additions & 2 deletions kronfluence/utils/dataset.py
@@ -17,7 +17,7 @@
@dataclass
class DataLoaderKwargs(KwargsHandler):
"""The object used to customize `DataLoader`. Please refer to https://pytorch.org/docs/stable/data.html for
detailed information of each argument. The default arguments are copied from PyTorch version 2.2.
detailed information of each argument. The default arguments are copied from PyTorch version 2.3.
"""

num_workers: int = 0
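A minimal usage sketch for `DataLoaderKwargs` (only `num_workers` appears in this diff; how the object is passed to the analyzer is not shown here, so treat the surrounding code as an assumption):

    from kronfluence.utils.dataset import DataLoaderKwargs

    # Override only the loader settings you need; unspecified fields keep the
    # PyTorch 2.3 defaults mentioned in the docstring above.
    loader_kwargs = DataLoaderKwargs(num_workers=4)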
@@ -115,7 +115,7 @@ class DistributedSamplerWithStack(Sampler[T_co]):
"""DistributedSampleWithStack is different from `DistributedSampler`. Instead of subsampling,
it stacks the dataset. For example, when `num_replicas` is 3, and the dataset of [0, ..., 9] is given,
the first, second, and third rank should have [0, 1, 2], [3, 4, 5], and [6, 7, 8], respectively. However,
it still adds extra samples to make the dataset evenly divisible.
it still adds extra samples to make the dataset evenly divisible (different from DistributedEvalSampler).
"""

def __init__( # pylint: disable=super-init-not-called
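To make the stacking behavior concrete, a small sketch of the index assignment (an illustration based on the docstring's description, not the class's actual code), including the padding that keeps the dataset evenly divisible:

    import math

    def stacked_indices(dataset_size: int, num_replicas: int, rank: int) -> list:
        # Pad the index list so it divides evenly, then give each rank a
        # contiguous block instead of an interleaved subsample.
        per_rank = math.ceil(dataset_size / num_replicas)
        indices = list(range(dataset_size))
        indices += indices[: per_rank * num_replicas - dataset_size]
        return indices[rank * per_rank : (rank + 1) * per_rank]

    # With 10 samples and 3 replicas: [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1].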
6 changes: 4 additions & 2 deletions kronfluence/utils/logger.py
@@ -143,7 +143,7 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str


class PassThroughProfiler(Profiler):
"""A pass through Profiler objective that does not record timing for the profiler."""
"""A pass through Profiler objective that does not record timing."""

def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""
@@ -161,6 +161,8 @@ def summary(self) -> str:
class TorchProfiler(Profiler):
"""A PyTorch Profiler objective that provides detailed profiling information:
https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html.
This is useful for low-level profiling in PyTorch, and is not used by default.
"""

def __init__(self, state: State) -> None:
@@ -174,7 +176,7 @@ def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""
if action_name in self.current_actions:
raise ValueError(f"Attempted to start {action_name} which has already started.")
# Set dummy value, since only used to track duplicate actions
# Set dummy value, since only used to track duplicate actions.
self.current_actions[action_name] = 0.0
self.actions.append(action_name)
self._torch_prof.start()
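For reference, the low-level profiling that `TorchProfiler` wraps can be reproduced with the standard `torch.profiler` API (a generic PyTorch sketch, independent of kronfluence):

    import torch
    from torch.profiler import ProfilerActivity, profile

    model = torch.nn.Linear(128, 64)
    inputs = torch.randn(32, 128)

    # Record operator-level timings for one forward pass.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, record_shapes=True) as prof:
        model(inputs)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))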
58 changes: 33 additions & 25 deletions kronfluence/utils/model.py
@@ -15,22 +15,26 @@


def apply_ddp(
model: torch.nn.Module,
model: nn.Module,
local_rank: int,
rank: int,
world_size: int,
) -> DistributedDataParallel:
"""
Applies DistributedDataParallel (DDP) to the given model.
"""Applies DistributedDataParallel (DDP) to the given model.
Args:
model (torch.nn.Module): The model to apply DDP to.
local_rank (int): The local rank of the current process.
rank (int): The rank of the current process.
world_size (int): The total number of processes.
model (nn.Module):
The model for which DDP will be applied.
local_rank (int):
The local rank of the current process.
rank (int):
The rank of the current process.
world_size (int):
The total number of processes.
Returns:
DistributedDataParallel: The model wrapped with DDP.
DistributedDataParallel:
The model wrapped with DDP.
"""
dist.init_process_group("nccl", rank=rank, world_size=world_size)
device = torch.device(f"cuda:{local_rank}")
@@ -48,7 +52,7 @@ def apply_ddp(


def apply_fsdp(
model: torch.nn.Module,
model: nn.Module,
local_rank: int,
rank: int,
world_size: int,
@@ -57,27 +61,31 @@
is_transformer: bool = False,
layer_to_wrap: Optional[nn.Module] = None,
) -> FSDP:
"""
Applies FullyShardedDataParallel (FSDP) to the given model.
"""Applies FullyShardedDataParallel (FSDP) to the given model.
Args:
model (torch.nn.Module): The model to apply FSDP to.
local_rank (int): The local rank of the current process.
rank (int): The rank of the current process.
world_size (int): The total number of processes.
sharding_strategy (str): The sharding strategy to use.
Defaults to "FULL_SHARD".
cpu_offload (bool): Whether to offload parameters to CPU. Check
https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.CPUOffload.
Defaults to True.
is_transformer (bool): Whether the model is a transformer model.
Defaults to False.
layer_to_wrap (nn.Module, optional): The specific layer to wrap
for transformer models. Required if `is_transformer` is True.
model (nn.Module):
The model for which FSDP will be applied.
local_rank (int):
The local rank of the current process.
rank (int):
The rank of the current process.
world_size (int):
The total number of processes.
sharding_strategy (str):
The sharding strategy to use. Defaults to "FULL_SHARD".
cpu_offload (bool):
Whether to offload parameters to CPU. Check
https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.CPUOffload. Defaults to True.
is_transformer (bool):
Whether the model is a transformer model. Defaults to False.
layer_to_wrap (nn.Module, optional):
The specific layer to wrap for transformer models. Required if `is_transformer` is True.
Defaults to None.
Returns:
FSDP: The model wrapped with FSDP.
FullyShardedDataParallel:
The model wrapped with FSDP.
"""
dist.init_process_group("nccl", rank=rank, world_size=world_size)
device = torch.device(f"cuda:{local_rank}")
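A launch sketch for these wrappers, assuming one process per GPU spawned with `torch.multiprocessing` on a single node (so `local_rank == rank`); the `MASTER_ADDR`/`MASTER_PORT` values are placeholders and the training/scoring code is elided:

    import os

    import torch
    import torch.multiprocessing as mp

    from kronfluence.utils.model import apply_ddp

    def worker(rank: int, world_size: int) -> None:
        # `apply_ddp` calls `dist.init_process_group` internally, which reads
        # these rendezvous settings from the environment.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        model = torch.nn.Linear(128, 64)
        ddp_model = apply_ddp(model=model, local_rank=rank, rank=rank, world_size=world_size)
        # ... run factor / score computation with `ddp_model` here ...
        # `apply_fsdp` is called the same way, with the extra sharding_strategy,
        # cpu_offload, is_transformer, and layer_to_wrap options shown above.

    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        mp.spawn(worker, args=(world_size,), nprocs=world_size)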
