
Commit

added new FTS config option frozen_bn_track_running_stats along with associated testing and documentation.
speediedan committed May 10, 2024
1 parent 365376b commit 9d7014b
Showing 8 changed files with 571 additions and 52 deletions.
10 changes: 10 additions & 0 deletions docs/source/index.rst
@@ -185,6 +185,16 @@ either integers or convertible to integers via ``int()``.
``0`` of the current fine-tuning schedule. This auto-configuration can be disabled if desired by setting
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.enforce_phase0_params` to ``False``.

.. note::

When freezing ``torch.nn.modules.batchnorm._BatchNorm`` modules, Lightning by default disables
``BatchNorm.track_running_stats``. To override this behavior so that even frozen ``BatchNorm`` layers continue to
have ``track_running_stats`` set to ``True``, set the FTS parameter
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.frozen_bn_track_running_stats` to ``True``.
Beginning with FTS ``2.4.0``,
:paramref:`~finetuning_scheduler.fts.FinetuningScheduler.frozen_bn_track_running_stats` will default to ``True``.
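
A minimal usage sketch (assuming the callback is imported from the top-level ``finetuning_scheduler``
package and attached to a Lightning ``Trainer`` in the usual way; only ``frozen_bn_track_running_stats``
is specific to this note):

.. code-block:: python

    from lightning.pytorch import Trainer
    from finetuning_scheduler import FinetuningScheduler

    # keep running-statistics tracking enabled for BatchNorm layers even while they are frozen
    fts = FinetuningScheduler(frozen_bn_track_running_stats=True)
    trainer = Trainer(callbacks=[fts])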


EarlyStopping and Epoch-Driven Phase Transition Criteria
********************************************************

23 changes: 22 additions & 1 deletion src/finetuning_scheduler/fts.py
@@ -20,6 +20,7 @@
from copy import deepcopy
from pprint import pformat
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing_extensions import override

import lightning.pytorch as pl
import torch
@@ -82,6 +83,12 @@ class FinetuningScheduler(ScheduleImplMixin, ScheduleParsingMixin, CallbackDepMi
:class:`~finetuning_scheduler.fts_supporters.FTSCheckpoint` or
:class:`~finetuning_scheduler.fts_supporters.FTSEarlyStopping` callback instances.
.. note::
While :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of
:external+torch:class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, setting ``overlap_with_ddp`` to
``True`` is not supported because that optimizer mode only supports a single parameter group.
.. note::
While :class:`~finetuning_scheduler.fts.FinetuningScheduler` supports the use of
@@ -107,6 +114,7 @@ def __init__(
apply_lambdas_new_pgs: bool = False,
logging_level: int = logging.INFO,
enforce_phase0_params: bool = True,
frozen_bn_track_running_stats: bool = False,
):
r"""
Arguments used to define and configure a scheduled fine-tuning training session:
@@ -229,6 +237,11 @@ def __init__(
and present in the optimizer differs from the parameters specified in phase 0. Only the parameters
included in the optimizer are affected; the choice of optimizer, lr_scheduler etc. remains unaltered.
Defaults to ``True``.
frozen_bn_track_running_stats: When freezing ``torch.nn.modules.batchnorm._BatchNorm`` layers, whether
:class:`~finetuning_scheduler.fts.FinetuningScheduler` should set ``BatchNorm`` ``track_running_stats``
to ``True``. Setting this to ``True`` overrides the default Lightning behavior that sets
``BatchNorm`` ``track_running_stats`` to ``False`` when freezing ``BatchNorm`` layers. Defaults to
``False`` for backwards compatibility. The default will change to ``True`` with FTS >= 2.4.0.
Attributes:
_fts_state: The internal :class:`~finetuning_scheduler.fts.FinetuningScheduler` state.
@@ -255,7 +268,9 @@ def __init__(
self.allow_untested = allow_untested
self.apply_lambdas_new_pgs = apply_lambdas_new_pgs
self.enforce_phase0_params = enforce_phase0_params
self.frozen_bn_track_running_stats = frozen_bn_track_running_stats
self._has_reinit_schedule = False
self._msg_cache = set()
rz_logger = logging.getLogger("lightning.pytorch.utilities.rank_zero")
rz_logger.setLevel(logging_level)

@@ -292,6 +307,7 @@ def _supported_strategy_flags() -> Sequence[str]:
# "deepspeed", # relevant FTS strategy adapter not yet available, PRs welcome!
)

@override
def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
"""Freezes all model parameters so that parameter subsets can be subsequently thawed according to the fine-
tuning schedule.
@@ -300,7 +316,11 @@ def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
pl_module (:external+pl:class:`~lightning.pytorch.core.module.LightningModule`): The target
:external+pl:class:`~lightning.pytorch.core.module.LightningModule` to freeze parameters of
"""
self.freeze(modules=pl_module, train_bn=False)
# We avoid overriding `BaseFinetuning`'s `freeze` and `freeze_module` methods at the small marginal cost
# of conditionally revisiting `BatchNorm` layers to set `track_running_stats` to `True` when we are in
# `frozen_bn_track_running_stats` mode.
BaseFinetuning.freeze(modules=pl_module, train_bn=False)
self.strategy_adapter._module_specific_freezing(modules=pl_module)

def step(self) -> None:
"""Prepare and execute the next scheduled fine-tuning level
@@ -387,6 +407,7 @@ def step_pg(
else:
thaw_layers = {depth: self.ft_schedule[depth]}.items()
for i, orig_next_tl in thaw_layers:
self.strategy_adapter._maybe_set_bn_track_running_stats(i)
next_tl = deepcopy(orig_next_tl)
if i <= depth:
next_tl["params"] = self.strategy_adapter.fts_optim_transform(next_tl["params"])
21 changes: 20 additions & 1 deletion src/finetuning_scheduler/fts_supporters.py
@@ -30,7 +30,7 @@
from dataclasses import dataclass, field, fields
from functools import reduce
from pprint import pformat
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Set
from typing_extensions import TypeAlias

import lightning.pytorch as pl
@@ -1271,6 +1271,7 @@ class ScheduleImplMixin(ABC):
reinit_optim_cfg: Optional[Dict]
reinit_lr_cfg: Optional[Dict]
max_depth: int
_msg_cache: Set
_fts_state: FTSState
PHASE_0_DIVERGENCE_MSG = (
"After executing the provided `configure_optimizers` method, the optimizer state differs from the configuration"
@@ -1479,6 +1480,7 @@ def thaw_to_depth(self, depth: Optional[int] = None) -> None:
depth = depth or self.curr_depth
for i, next_tl in self.ft_schedule.items(): # type: ignore[union-attr]
if i <= depth:
self.strategy_adapter._maybe_set_bn_track_running_stats(i)
_, self._fts_state._curr_thawed_params = self.strategy_adapter.exec_ft_phase(
self.pl_module, thaw_pl=self.strategy_adapter.fts_optim_transform(next_tl["params"])
)
@@ -1749,6 +1751,23 @@ def _validate_opt_init(self) -> None:
)
rank_zero_warn(w_msg)

def _conditional_warn_once(self, condition: Any, msg: str) -> None:
"""A helper function that conditionally issues a warning message only once based on the provided condition
variable. Robust to context managers that may prevent warnings.filterwarnings("once") from behaving as
expected.
Args:
condition (Any): The condition to evaluate for issuing the warning.
msg (str): The warning message to display.
Returns:
None
"""
if not bool(condition) or msg in self._msg_cache:
return
self._msg_cache.add(msg)
rank_zero_warn(msg)
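
# A minimal, self-contained sketch of the warn-once pattern used by `_conditional_warn_once`, with the
# stdlib `warnings` module standing in for `rank_zero_warn`; the class and message below are hypothetical.
import warnings
from typing import Any, Set


class _WarnOnceSketch:
    def __init__(self) -> None:
        self._msg_cache: Set[str] = set()  # messages that have already been emitted

    def warn_once(self, condition: Any, msg: str) -> None:
        # skip when the condition is falsy or this exact message has already been issued
        if not bool(condition) or msg in self._msg_cache:
            return
        self._msg_cache.add(msg)
        warnings.warn(msg)


sketch = _WarnOnceSketch()
for _ in range(3):
    sketch.warn_once(True, "`frozen_bn_track_running_stats` will default to `True` in 2.4.0")  # warns only once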


class CallbackDepMixin(ABC):
"""Functionality for validating/managing callback dependencies."""
98 changes: 94 additions & 4 deletions src/finetuning_scheduler/strategy_adapters/base.py
@@ -18,15 +18,17 @@
"""
from functools import partialmethod
from pprint import pformat as pfmt
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Dict

import torch

from lightning.fabric.utilities import rank_zero_info
from lightning.fabric.utilities.types import ReduceLROnPlateau
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import BaseFinetuning
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
from torch.nn import Module


class StrategyAdapter:
@@ -51,6 +53,14 @@ class StrategyAdapter:
"""

fts_handle: Callback
_ft_schedule_module_map: Dict
_unscheduled_params: List

FROZEN_BN_DEFAULT_WARN = ( # TODO: remove warning with release of FTS 2.4.0
"Starting with the next minor release of FTS (2.4.0), the default value for `frozen_bn_track_running_stats`"
" will change to `True`. To retain the current `track_running_stats` `False` behavior with FTS >= 2.4.0, frozen"
" `BatchNorm` layers like those in this model will require setting `frozen_bn_track_running_stats` to `False`."
)

def __init__(self) -> None:
"""The default fine-tuning phase execution function is set on
@@ -106,6 +116,9 @@ def on_after_init_fts(self) -> None:
"""Hook executed in :class:`~finetuning_scheduler.fts.FinetuningScheduler` setup immediately after
:meth:`~finetuning_scheduler.fts_supporters.ScheduleImplMixin.init_fts`.
"""
self._gen_ft_sched_module_map()
self.scheduled_mod_lists = [list(self._ft_schedule_module_map[d]) for d in self._ft_schedule_module_map.keys()]
self._maybe_set_bn_track_running_stats(0)
_, self.fts_handle._fts_state._curr_thawed_params = self.exec_ft_phase(
self.pl_module,
thaw_pl=self.fts_optim_transform(self.fts_handle.ft_schedule[0]["params"]),
@@ -160,6 +173,26 @@ def logical_param_translation(self, param_names: List) -> List:
"""
return param_names

def _gen_ft_sched_module_map(self) -> None:
"""Generate a module-level mapping of the modules associated with each fine-tuning phase, including modules
not present in the fine-tuning schedule grouped together into a single unscheduled phase to facilitate the
relevant disjointness check."""
assert isinstance(self.fts_handle.ft_schedule, Dict)
module_map: Dict = {}
for depth in self.fts_handle.ft_schedule.keys(): # type: ignore[union-attr]
phase_params = self.fts_handle.ft_schedule[depth].get("params", []) # type: ignore[union-attr]
module_map[depth] = set()
for p in phase_params:
module_map[depth].add(p.rpartition(".")[0])
self._ft_schedule_module_map = module_map
scheduled_mods = list(set().union(*module_map.values()))
unscheduled_mods = tuple(
n for n, m in self.pl_module.named_modules() if n not in scheduled_mods and m._parameters
)
self._unscheduled_params = [
f"{m}.{n}" for m in unscheduled_mods for n, _ in self.pl_module.get_submodule(m).named_parameters()
]
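
# A minimal sketch of the mapping constructed above, using a hypothetical two-phase schedule: each
# parameter name is reduced to its owning module via `rpartition(".")`, as in `_gen_ft_sched_module_map`.
ft_schedule = {
    0: {"params": ["model.classifier.weight", "model.classifier.bias"]},
    1: {"params": ["model.encoder.layer.0.attention.query.weight"]},
}
module_map = {depth: {p.rpartition(".")[0] for p in phase["params"]} for depth, phase in ft_schedule.items()}
# -> {0: {'model.classifier'}, 1: {'model.encoder.layer.0.attention.query'}}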

@staticmethod
def _clean_optim_lr_pgs(trainer: Trainer) -> List:
"""Delete existing param groups from an optimizer that was found to be misaligned with respect to phase 0
@@ -246,8 +279,8 @@ def phase0_optimizer_override(self) -> None:

@staticmethod
def base_ft_phase(
module: Module, thaw_pl: List, translation_func: Optional[Callable] = None, init_thaw: bool = False
) -> Tuple[List, List]:
module: torch.nn.Module, thaw_pl: List, translation_func: Optional[Callable] = None, init_thaw: bool = False) \
-> Tuple[List, List]:
"""Thaw/unfreeze the provided list of parameters in the provided :class:`~torch.nn.Module`
Args:
@@ -281,4 +314,61 @@ def base_ft_phase(
)
return thawed_p_names, curr_thawed

####################################################################################################################
# BatchNorm module-specific handling
# (if additional modules require special handling, these will be refactored to accommodate a more generic
# dispatching pattern for module-specific handling)
####################################################################################################################

def _module_specific_freezing(self, modules: torch.nn.Module) -> None:
"""Orchestrates module-specific freezing behavior. Currently only.
:external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` layers require special handling. Running
statistics tracking for frozen `BatchNorm` layers is conditionally re-enabled here based on the
`frozen_bn_track_running_stats` flag.
Args:
modules (torch.nn.Module): The modules for which the `BatchNorm` layer running statistics should be enabled.
Returns:
None
"""
if self.fts_handle.frozen_bn_track_running_stats:
rank_zero_info("Since `frozen_bn_track_running_stats` is currently set to `True`, FinetuningScheduler"
" will set `track_running_stats` to `True` for all `BatchNorm` layers.")
modules = BaseFinetuning.flatten_modules(modules) # type: ignore[assignment]
for mod in modules:
if isinstance(mod, torch.nn.modules.batchnorm._BatchNorm):
mod.track_running_stats = True

def _maybe_set_bn_track_running_stats(self, schedule_phase: int) -> None:
"""Enable `track_running_stats` for :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules
that may require it based on `frozen_bn_track_running_stats` and a given schedule phase.
Args:
schedule_phase (int): The phase of the schedule to evaluate.
Returns:
None
"""
if not self.fts_handle.frozen_bn_track_running_stats:
target_bn_modules = self._get_target_bn_modules(schedule_phase)
self.fts_handle._conditional_warn_once(target_bn_modules, self.FROZEN_BN_DEFAULT_WARN)
for _, m in target_bn_modules:
m.track_running_stats = True

def _get_target_bn_modules(self, schedule_phase: int) -> List:
"""Enumerate the :external+torch:class:`~torch.nn.modules.batchnorm._BatchNorm` modules for a given
schedule phase.
Args:
schedule_phase (int): The phase of the schedule to evaluate.
Returns:
List[Tuple[str, torch.nn.modules.batchnorm._BatchNorm]]: A list of tuples containing the names and instances
of `BatchNorm` modules associated with a given schedule phase.
"""
return [(n, m) for n, m in self.pl_module.named_modules() if
n in self.scheduled_mod_lists[schedule_phase] and
isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
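
# A self-contained sketch (hypothetical model) of the enumeration above: locate `_BatchNorm` modules by
# inspecting `named_modules()` and re-enable running-statistics tracking on them, mirroring
# `_get_target_bn_modules` / `_maybe_set_bn_track_running_stats`.
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU())
target_bn_modules = [
    (n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)
]
for _, m in target_bn_modules:
    m.track_running_stats = True  # keep tracking running statistics even while the layer is frozen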

fts_optim_inspect = partialmethod(fts_optim_transform, inspect_only=True)