add Lit EMA support to AveragedModel (#652)
Summary:
Pull Request resolved: #652

This diff adds Lit EMA support directly to TNT's `AveragedModel`; it can be enabled through the `use_lit` flag (default `False`).

The new args are forwarded appropriately in AutoUnit as well.

Reviewed By: galrotem

Differential Revision: D52103639

fbshipit-source-id: 27cec2e5b77e4085e12ca9946db446facba54918
JKSenthil authored and facebook-github-bot committed Dec 14, 2023
1 parent c32f68b commit 37b8a7a
Showing 3 changed files with 61 additions and 1 deletion.
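
For quick context on the summary above, here is a minimal sketch of enabling the new flag directly on `AveragedModel`. It mirrors the new test below; the import path `torchtnt.utils.swa` and the training loop are illustrative assumptions, not part of the diff.

    import torch
    from torchtnt.utils.swa import AveragedModel

    model = torch.nn.Linear(10, 10)
    # use_lit=True switches EMA to the Lit EMA schedule, where the effective decay
    # starts small and ramps up toward ema_decay as updates accumulate
    averaged_model = AveragedModel(
        model,
        averaging_method="ema",
        ema_decay=0.999,
        use_lit=True,
    )

    for _ in range(100):
        ...  # forward/backward/optimizer step on `model`
        averaged_model.update_parameters(model)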
31 changes: 31 additions & 0 deletions tests/utils/test_swa.py
@@ -193,3 +193,34 @@ def test_input_checks(self) -> None:
        ):
            # pyre-ignore On purpose to test run time exception
            AveragedModel(model, averaging_method="foo")

    def test_lit_ema(self) -> None:
        model = torch.nn.Linear(10, 10)
        ema_decay = 0.999
        averaged_model = AveragedModel(
            model,
            averaging_method="ema",
            ema_decay=ema_decay,
            use_lit=True,
        )

        averaged_params = [torch.zeros_like(param) for param in model.parameters()]

        n_updates = 10
        for i in range(n_updates):
            decay = min(ema_decay, (1 + i + 1) / (10 + i + 1))

            updated_averaged_params = []
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * decay + p * (1 - decay)).clone()
                    )
            averaged_model.update_parameters(model)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
            torch.testing.assert_close(p_avg, p_swa, check_device=False)
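
The reference computation above follows the Lit EMA schedule decay = min(ema_decay, (1 + n) / (10 + n)), where n is the number of updates so far. A small illustrative snippet (not part of the diff) shows how that effective decay ramps up:

    ema_decay = 0.999
    for n in (1, 2, 10, 100, 1000, 10000):
        decay = min(ema_decay, (1 + n) / (10 + n))
        print(n, round(decay, 4))
    # n=1 -> 0.1818, n=10 -> 0.55, n=100 -> 0.9182, n=1000 -> 0.9911, n=10000 -> 0.999 (capped at ema_decay)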
8 changes: 7 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -85,8 +85,12 @@ class SWAParams:
        warmup_steps_or_epochs: number of steps or epochs before starting SWA
        step_or_epoch_update_freq: number of steps or epochs between each SWA update
        averaging_method: whether to use SWA or EMA to average model weights
        ema_decay: the exponential decay applied to the averaged parameters. This param
            is only needed for EMA, and is ignored otherwise (for SWA).
        use_lit: if True, will use Lit EMA style by adjusting the EMA decay based on the
            number of updates. The EMA decay will start small and will approach the
            specified ema_decay as more updates occur. The ``averaging_method`` must be
            set to "ema".
        swalr_params: params for SWA learning rate scheduler

    Note: Whether steps or epochs are used depends on what `step_lr_interval` is set to on the AutoUnit.
@@ -99,6 +103,7 @@ class SWAParams:
    step_or_epoch_update_freq: int
    averaging_method: Literal["ema", "swa"] = "ema"
    ema_decay: float = 0.999
    use_lit: bool = False
    swalr_params: Optional[SWALRParams] = None


@@ -481,6 +486,7 @@ def __init__(
                averaging_method=swa_params.averaging_method,
                ema_decay=swa_params.ema_decay,
                skip_deepcopy=skip_deepcopy,
                use_lit=swa_params.use_lit,
            )

        self.module: torch.nn.Module = prepare_module(
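
For the AutoUnit path, a hedged configuration sketch using the `SWAParams` fields shown above (assuming `SWAParams` is imported from `torchtnt.framework.auto_unit`, where this diff defines it; the concrete values are illustrative):

    from torchtnt.framework.auto_unit import SWAParams

    swa_params = SWAParams(
        warmup_steps_or_epochs=1000,
        step_or_epoch_update_freq=10,
        averaging_method="ema",
        ema_decay=0.999,
        use_lit=True,  # forwarded to AveragedModel; requires averaging_method="ema"
    )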
23 changes: 23 additions & 0 deletions torchtnt/utils/swa.py
@@ -29,6 +29,7 @@ def __init__(
        averaging_method: Literal["ema", "swa"] = "ema",
        ema_decay: float = 0.999,
        skip_deepcopy: bool = False,
        use_lit: bool = False,
    ) -> None:
        """
        This class is a custom version of AveragedModel that allows us to skip the
@@ -44,6 +45,9 @@ def __init__(
                is only needed for EMA, and is ignored otherwise (for SWA).
            skip_deepcopy: If True, will skip the deepcopy step. The user must ensure
                that the module passed in is already copied in some way.
            use_lit: If True, will use Lit EMA style by adjusting the EMA decay based on the
                number of updates. The EMA decay will start small and will approach the
                specified ema_decay as more updates occur.
        """
        # setup averaging method
        if averaging_method == "ema":
@@ -57,11 +61,18 @@ def __init__(
            # TODO: torch/optim/swa_utils.pyi needs to be updated
            # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
            multi_avg_fn = get_swa_multi_avg_fn()

            if use_lit:
                raise ValueError("LitEMA is only supported for EMA.")
        else:
            raise ValueError(
                f"Unknown averaging method: {averaging_method}. Only ema and swa are supported."
            )

        self._ema_decay = ema_decay
        self._use_lit = use_lit
        self._num_updates = 0

        if skip_deepcopy:
            # calls parent init manually, but skips deepcopy step
            torch.nn.Module.__init__(self)  # inits grandparent class
@@ -84,3 +95,15 @@ def __init__(
            multi_avg_fn=multi_avg_fn,
            use_buffers=use_buffers,
        )

    def update_parameters(self, model: torch.nn.Module) -> None:
        self._num_updates += 1
        if self._use_lit:
            decay = min(
                self._ema_decay, (1 + self._num_updates) / (10 + self._num_updates)
            )

            # TODO: torch/optim/swa_utils.pyi needs to be updated
            # pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
            self.multi_avg_fn = get_ema_multi_avg_fn(decay)
        super().update_parameters(model)
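
For reference, the per-parameter rule that `get_ema_multi_avg_fn(decay)` applies is the standard EMA interpolation, matching the reference computation in the new test. A minimal single-tensor sketch of that in-place rule (the real helper operates on whole parameter lists):

    import torch

    def ema_update_(avg: torch.Tensor, cur: torch.Tensor, decay: float) -> None:
        # avg <- decay * avg + (1 - decay) * cur, applied in place
        avg.mul_(decay).add_(cur, alpha=1 - decay)

With `use_lit=True`, the `decay` fed into this rule is recomputed on every `update_parameters` call from the running update count, instead of staying fixed at `ema_decay`.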
