update type annotations/hints to accommodate mypy 1.2, bump PL dep dev sha
speediedan committed May 1, 2023
1 parent 3c1a23e commit b9434f3
Showing 7 changed files with 42 additions and 25 deletions.
3 changes: 1 addition & 2 deletions requirements/base.txt
@@ -1,5 +1,4 @@
#lightning>=2.0.0,<2.0.1
# the below is uncommented when master is targeting a specific pl dev master commit
- git+https://github.com/Lightning-AI/lightning.git@e9d685635580e6150a0e65d9b4c7ee8ad2a1de71#egg=lightning
- #git+https://github.com/Lightning-AI/lightning.git@1d1f6009630d01f5347a7234dad97f6c75f93af0#egg=lightning
+ git+https://github.com/Lightning-AI/lightning.git@7c6d42a6b9ab3f20ac5495676960b73e993f85c4#egg=lightning
torch>=1.11.0
3 changes: 1 addition & 2 deletions requirements/standalone_base.txt
@@ -1,5 +1,4 @@
#pytorch-lightning>=2.0.0,<2.0.1
# the below is uncommented when master is targeting a specific pl dev master commit
- git+https://github.com/Lightning-AI/pytorch-lightning.git@e9d685635580e6150a0e65d9b4c7ee8ad2a1de71#egg=pytorch-lightning
- #git+https://github.com/Lightning-AI/pytorch-lightning.git@1d1f6009630d01f5347a7234dad97f6c75f93af0#egg=pytorch-lightning
+ git+https://github.com/Lightning-AI/pytorch-lightning.git@7c6d42a6b9ab3f20ac5495676960b73e993f85c4#egg=pytorch-lightning
torch>=1.11.0
3 changes: 1 addition & 2 deletions setup.py
@@ -137,8 +137,7 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]:
_INSTALL_PATHS["require"],
file_name=base_reqs,
standalone=standalone,
- pl_commit="e9d685635580e6150a0e65d9b4c7ee8ad2a1de71",
- # pl_commit="1d1f6009630d01f5347a7234dad97f6c75f93af0",
+ pl_commit="7c6d42a6b9ab3f20ac5495676960b73e993f85c4",
)
base_setup["install_requires"] = install_requires
return base_setup
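The `pl_commit` argument above feeds the same pinning mechanism seen in the requirements files. A hypothetical sketch (the real helper behind `_setup_args` is not shown in this diff; the function name below is illustrative only) of rendering a commit sha into pip's VCS-requirement form:

def pin_lightning_requirement(pl_commit: str, package: str = "lightning") -> str:
    # pip's VCS requirement form: git+<repo-url>@<commit-sha>#egg=<package>
    return f"git+https://github.com/Lightning-AI/{package}.git@{pl_commit}#egg={package}"

# reproduces the newly pinned line in requirements/base.txt above
print(pin_lightning_requirement("7c6d42a6b9ab3f20ac5495676960b73e993f85c4"))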
13 changes: 8 additions & 5 deletions src/finetuning_scheduler/fts.py
@@ -320,13 +320,13 @@ def step(self) -> None:
self.restore_best_ckpt()
self.step_pg(
depth=self.curr_depth,
- optimizer=self.pl_module.trainer.optimizers[0],
+ optimizer=self.pl_module.trainer.optimizers[0], # type: ignore[arg-type]
pre_reinit_state=pre_reinit_state,
)
else:
self.step_pg(
depth=self.curr_depth,
- optimizer=self.pl_module.trainer.optimizers[0],
+ optimizer=self.pl_module.trainer.optimizers[0], # type: ignore[arg-type]
depth_sync=False,
pre_reinit_state=pre_reinit_state,
)
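The `# type: ignore[arg-type]` comments added throughout this commit are error-code-scoped ignores: they silence only the named mypy error on that line rather than all checking. A minimal sketch (assumed example, not project code) of the pattern:

from typing import Optional

import torch


def step_with(optimizer: torch.optim.Optimizer) -> None:
    ...


maybe_optimizer: Optional[torch.optim.Optimizer] = None
# mypy reports: Argument 1 to "step_with" has incompatible type
# "Optional[Optimizer]"; expected "Optimizer"  [arg-type]
step_with(maybe_optimizer)  # type: ignore[arg-type]  # silences only the arg-type error

Keeping the error code in brackets pairs well with mypy's `warn_unused_ignores` setting, which flags ignores that later become unnecessary.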
@@ -474,7 +474,7 @@ def restore_best_ckpt(self) -> None:
self._fts_state._fts_ckpt_metadata["best_ckpt_pgs"][opt_idx], dict(self.pl_module.named_parameters())
)
if self.strategy_adapter.using_sharded_optimizer:
- ScheduleImplMixin._repartition_sharded_optim(optimizer)
+ ScheduleImplMixin._repartition_sharded_optim(optimizer) # type: ignore[arg-type]
# we're restoring everything but callbacks and loops, otherwise, checkpoint_connector.restore() could be used
assert self.pl_module.trainer.checkpoint_callback is not None
checkpoint_path = self.pl_module.trainer.checkpoint_callback.best_model_path # type: ignore[attr-defined]
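The `assert ... is not None` lines added in this commit (here and in `fsdp.py` below) take the other route: instead of an ignore comment, the assert narrows the Optional type for mypy. A minimal sketch, using a hypothetical helper:

from typing import Optional


def resolve_best_path(best_model_path: Optional[str]) -> str:
    assert best_model_path is not None  # mypy narrows Optional[str] to str past this point
    return best_model_path


print(resolve_best_path("checkpoints/best.ckpt"))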
@@ -530,7 +530,7 @@ def _reduce_transition(self, strategy: Strategy, decision: bool) -> bool:
bool: The reduced decision across all world processes.
"""
decision = torch.tensor(int(decision), device=strategy.root_device)
- decision = bool(strategy.reduce(decision, reduce_op=ReduceOp.SUM))
+ decision = bool(strategy.reduce(decision, reduce_op=ReduceOp.SUM)) # type:ignore[arg-type]
return decision
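For context on the reduction above, summing per-rank 0/1 flags and casting back to bool amounts to a logical OR across processes. A minimal local sketch (no distributed strategy involved; the per-rank votes are made up):

import torch

per_rank_decisions = [False, True, False]  # hypothetical votes from three ranks
flags = torch.tensor([int(d) for d in per_rank_decisions])
reduced = flags.sum()  # what a ReduceOp.SUM reduction yields across ranks
global_decision = bool(reduced)  # nonzero sum -> True, i.e. at least one rank voted True
assert global_decision is True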

def _sync_es_state(self, trainer: "pl.Trainer") -> None:
@@ -798,7 +798,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self._store(pl_module, opt_idx, num_saved_groups, current_param_groups)

def on_before_zero_grad(
- self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: ParamGroupAddable
+ self,
+ trainer: "pl.Trainer",
+ pl_module: "pl.LightningModule",
+ optimizer: ParamGroupAddable, # type: ignore[override]
) -> None:
"""After the latest optimizer step, update the
:attr:`~finetuning_scheduler.fts.FinetuningScheduler._fts_state`, incrementing the
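The `# type: ignore[override]` above is needed because the hook narrows the `optimizer` parameter relative to the parent callback signature, which mypy flags as a Liskov substitution violation. A minimal sketch with hypothetical classes:

class BaseCallback:
    def on_before_zero_grad(self, optimizer: object) -> None:
        ...


class NarrowingCallback(BaseCallback):
    # Narrowing a parameter type in an override is unsafe for callers that only
    # know about BaseCallback, so mypy reports an [override] error here.
    def on_before_zero_grad(self, optimizer: int) -> None:  # type: ignore[override]
        ...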
34 changes: 23 additions & 11 deletions src/finetuning_scheduler/fts_supporters.py
@@ -942,7 +942,7 @@ def _save_pre_reinit_lr_state(self, trainer: pl.Trainer) -> Tuple[Dict, List]:
Tuple[Dict, List]: The lr state to restore from the current lr scheduler and the most recent `lr`s for
parameter groups associated with the current phase's optimizer.
"""
- curr_lr_state = {}
+ curr_lr_state: Dict = {}
if trainer.lr_scheduler_configs:
curr_lr_state = deepcopy(trainer.lr_scheduler_configs[0].scheduler.state_dict())
prev_optimizer_lrs = copy([group["lr"] for group in trainer.strategy.optimizers[0].param_groups])
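The new `Dict` annotation above addresses mypy's inability to infer a type for an empty literal. A minimal sketch (assumed example) of the kind of error the annotation avoids:

from typing import Dict

# Without an annotation, mypy reports "Need type annotation for 'curr_lr_state'"
# ([var-annotated]) because an empty {} literal gives it nothing to infer from.
curr_lr_state: Dict = {}
curr_lr_state = {"last_epoch": 3}  # the later, populated assignment then type-checks cleanly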
@@ -963,18 +963,20 @@ def reinit_optimizer(self, new_optimizer: Dict, trainer: pl.Trainer, init_params
prev_optim_repr = repr(trainer.strategy.optimizers[0])
optimizer_class = self._import_reinit_class(optimizer_init, reinit_target="optimizer")
reinit_pgs = self._reinit_phase0_pgs(thawed_pl=init_params)
- new_optimizer_handle = optimizer_class(reinit_pgs, **optimizer_init.get("init_args", {}))
+ new_optimizer_handle = optimizer_class(
+ reinit_pgs, **optimizer_init.get("init_args", {}) # type: ignore[operator, arg-type]
+ )
# If the user or optimizer doesn't set `initial_lr` keys, add them based on the initial lr values.
# The latest LR state will still be set in subsequent phases, but this allows subsequent lr scheduler
# reinitializations to access an `initial_lr` for the existing optimizer if desired (important for consistency
# with lr scheduler-only reinitializations).
- for group in new_optimizer_handle.param_groups:
+ for group in new_optimizer_handle.param_groups: # type: ignore[union-attr]
group["initial_lr"] = group.get("initial_lr", group["lr"])
- trainer.strategy.optimizers = [new_optimizer_handle]
+ trainer.strategy.optimizers = [new_optimizer_handle] # type: ignore[list-item]
if trainer.lr_scheduler_configs:
trainer.lr_scheduler_configs[0].scheduler.optimizer = new_optimizer_handle
self._maybe_trace_reinit("optimizer", prev_optim_repr, repr(trainer.strategy.optimizers[0]))
- return new_optimizer_handle
+ return new_optimizer_handle # type:ignore[return-value]
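An alternative to the ignore comments above, when the imported class is known more precisely at a given call site, is `typing.cast`. A minimal sketch (assumption: illustrative only; the project instead keeps the shared import path and scopes its ignores):

from typing import Type, cast

import torch


def import_class(class_path: str) -> object:
    module_name, class_name = class_path.rsplit(".", 1)
    module = __import__(module_name, fromlist=[class_name])
    return getattr(module, class_name)


# cast() tells mypy the dynamically imported object is a concrete optimizer class,
# so the construction below needs no per-line ignores.
sgd_class = cast(Type[torch.optim.SGD], import_class("torch.optim.SGD"))
optimizer = sgd_class([torch.nn.Parameter(torch.zeros(1))], lr=0.1)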

def reinit_lr_scheduler(self, new_lr_scheduler: Dict, trainer: pl.Trainer, optimizer: ParamGroupAddable) -> None:
"""Reinitialize the learning rate scheduler, using a validated learning rate scheduler configuration and
@@ -1008,11 +1010,14 @@ def reinit_lr_scheduler(self, new_lr_scheduler: Dict, trainer: pl.Trainer, optim
if reset_init_pg_lrs:
param_group["initial_lr"] = lr
if "pl_lrs_cfg" in new_lr_scheduler.keys():
- new_lr_scheduler["pl_lrs_cfg"] = self._update_pl_lrs(new_lr_scheduler["pl_lrs_cfg"], lrs_class=lrs_class)
+ new_lr_scheduler["pl_lrs_cfg"] = self._update_pl_lrs(
+ new_lr_scheduler["pl_lrs_cfg"], lrs_class=lrs_class
+ ) # type:ignore[arg-type]
+ assert callable(lrs_class)
new_lrs_config = LRSchedulerConfig(
scheduler=lrs_class(
- optimizer=optimizer, **lr_scheduler_init.get("init_args", {})
- ), # type: ignore[arg-type]
+ optimizer=optimizer, **lr_scheduler_init.get("init_args", {}) # type: ignore[arg-type]
+ ),
**new_lr_scheduler.get("pl_lrs_cfg", {}),
)
trainer.strategy.lr_scheduler_configs = [new_lrs_config]
@@ -1091,7 +1096,7 @@ def _is_supported_reinit_optimizer(self, optim_class: Union[Any, ParamGroupAddab
MisconfigurationException: If the provided optimizer class is known to be currently unsupported in the
context of optimizer reinitialization.
"""
- if issubclass(optim_class, ZeroRedundancyOptimizer):
+ if issubclass(optim_class, ZeroRedundancyOptimizer): # type: ignore[arg-type]
error_msg = (
f"The provided optimizer ({optim_class}) is not currently supported by FinetuningScheduler in the"
" context of optimizer reinitialization. Please use a currently supported torch optimizer (or subclass"
@@ -1116,6 +1121,7 @@ def _import_reinit_class(
Returns:
Union[FTSLRSchedulerType, ParamGroupAddable]: The class to reinitialize.
"""
+ # TODO: refactor this function to enable type narrowing while continuing to share relevant code paths
try:
class_module, class_name = reinit_cfg["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
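One possible shape for the TODO above (an assumption, not the project's plan) is `typing.overload` with `Literal` targets, which lets mypy narrow the returned class per call site while the implementation stays shared:

from typing import Literal, Type, Union, overload

import torch
from torch.optim.lr_scheduler import _LRScheduler


@overload
def import_reinit_class(class_path: str, reinit_target: Literal["optimizer"]) -> Type[torch.optim.Optimizer]: ...
@overload
def import_reinit_class(class_path: str, reinit_target: Literal["lr_scheduler"]) -> Type[_LRScheduler]: ...
def import_reinit_class(
    class_path: str, reinit_target: str
) -> Union[Type[torch.optim.Optimizer], Type[_LRScheduler]]:
    class_module, class_name = class_path.rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    return getattr(module, class_name)


optimizer_class = import_reinit_class("torch.optim.SGD", "optimizer")  # narrowed to Type[Optimizer]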
@@ -1184,7 +1190,10 @@ def _optimizer_sanity_chk(self, optimizer_init: Dict) -> None:
self._is_supported_reinit_optimizer(optimizer_class)
test_optimizer_init = copy(optimizer_init.get("init_args", {}))
try:
- test_optimizer = optimizer_class(ScheduleParsingMixin.SANITY_CHK_ITERABLE, **test_optimizer_init)
+ assert callable(optimizer_class)
+ test_optimizer = optimizer_class(
+ ScheduleParsingMixin.SANITY_CHK_ITERABLE, **test_optimizer_init # type: ignore[arg-type]
+ )
except Exception as err:
error_msg = (
"Could not configure the specified optimizer class using the `init_args` "
@@ -1222,7 +1231,9 @@ def _lr_scheduler_sanity_chk(self, lr_scheduler_init: Dict, is_implicit_mode: bo
invalid_min_lr = (
True if min_lr_param and (isinstance(min_lr_param, list) or isinstance(min_lr_param, tuple)) else False
)
- reinit_rlrop = is_implicit_mode and issubclass(lrs_class, torch.optim.lr_scheduler.ReduceLROnPlateau)
+ reinit_rlrop = is_implicit_mode and issubclass(
+ lrs_class, torch.optim.lr_scheduler.ReduceLROnPlateau # type: ignore[arg-type]
+ )
if reinit_rlrop and invalid_min_lr:
raise MisconfigurationException(
"In the lr scheduler configuration passed via `reinit_lr_cfg` (i.e. implicit mode training)"
@@ -1233,6 +1244,7 @@ def _lr_scheduler_sanity_chk(self, lr_scheduler_init: Dict, is_implicit_mode: bo
if min_lr_param:
del test_lr_init["min_lr"] # our mock optimizer will not have any param groups
try:
+ assert callable(lrs_class)
testlr = lrs_class(optimizer=_MockOptimizer(), **test_lr_init)
except Exception as err:
error_msg = (
6 changes: 4 additions & 2 deletions src/finetuning_scheduler/strategy_adapters/base.py
@@ -178,7 +178,7 @@ def _clean_optim_lr_pgs(trainer: Trainer) -> List:
orig_num_pgs.append(len(optimizer.param_groups))
optimizer.param_groups = []
for lrs_cfg in trainer.lr_scheduler_configs:
- lrs_cfg.scheduler.last_epoch = -1
+ lrs_cfg.scheduler.last_epoch = -1 # type: ignore[union-attr]
if not isinstance(lrs_cfg.scheduler, ReduceLROnPlateau):
lrs_cfg.scheduler.base_lrs = []
return orig_num_pgs
@@ -215,7 +215,9 @@ def _reconfigure_lrs_for_phase0(self, trainer: Trainer, orig_num_pgs: List) -> N
for lrs_cfg in trainer.lr_scheduler_configs:
if _TORCH_GREATER_EQUAL_1_13 and not isinstance(lrs_cfg.scheduler, ReduceLROnPlateau):
lrs_cfg.scheduler._initial_step()
- lrs_cfg.scheduler._last_lr = [group["lr"] for group in lrs_cfg.scheduler.optimizer.param_groups]
+ lrs_cfg.scheduler._last_lr = [
+ group["lr"] for group in lrs_cfg.scheduler.optimizer.param_groups # type: ignore[union-attr]
+ ]
if isinstance(lrs_cfg.scheduler, ReduceLROnPlateau):
lrs_cfg.scheduler.min_lrs = lrs_cfg.scheduler.min_lrs[orig_num_pgs[0] :]
elif hasattr(lrs_cfg.scheduler, "lr_lambdas"):
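The `union-attr` ignores in this file stem from `lr_scheduler_configs` holding a union of scheduler types; where restructuring is practical, an `isinstance` check narrows the union instead. A minimal sketch (assumed example):

from typing import Union

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler: Union[StepLR, ReduceLROnPlateau] = StepLR(optimizer, step_size=1)

if not isinstance(scheduler, ReduceLROnPlateau):
    # mypy narrows the union here, so attributes specific to standard lr schedulers,
    # such as `base_lrs` and `optimizer`, are accessible without [union-attr] ignores.
    print(scheduler.base_lrs, len(scheduler.optimizer.param_groups))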
5 changes: 4 additions & 1 deletion src/finetuning_scheduler/strategy_adapters/fsdp.py
@@ -488,6 +488,7 @@ def _validate_fsdp_phases_disjoint(self) -> Tuple:
"Bypassing FSDP-specific phase disjointness validation because `use_orig_params` is "
"``True`` and PyTorch is >= `2.1.0`"
)
+ assert self.pl_module._trainer is not None
# check only required for mixed-precision training with DEBUG level logging requested
if self.pl_module._trainer.precision in ("16-mixed", "bf16-mixed") and self._rank_zero_logger.level <= 10:
has_no_local_shards = self._log_nonzero_local_shards()
@@ -716,7 +717,9 @@ def _fts_auto_wrap(self) -> None:

# apply wrappers to enable activation checkpointing if requested
if self.pls_handle._activation_checkpointing:
- _setup_activation_checkpointing(module=self.pl_module, layers=self.pls_handle._activation_checkpointing)
+ _setup_activation_checkpointing(
+ module=self.pl_module, layers=self.pls_handle._activation_checkpointing # type: ignore[arg-type]
+ )

def _after_configure_sharded_model(self) -> None:
"""Generate the parameter-level bi-directional translations the FTS FSDP adapter requires and then execute
