refactor AveragedModel utils (#651)
Summary:
Pull Request resolved: #651

Refactors AveragedModel in TNT utils to be more streamlined: averaging behavior is now selected with averaging_method ("ema" or "swa") and ema_decay, replacing the user-supplied avg_fn / multi_avg_fn callables.

Reviewed By: galrotem

Differential Revision: D52093437

fbshipit-source-id: e2b8faae679688f89a116d93398b8cd92e4dd75b
JKSenthil authored and facebook-github-bot committed Dec 13, 2023
1 parent d995da2 commit c32f68b
Showing 4 changed files with 76 additions and 63 deletions.
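
Before the per-file diffs, here is a minimal sketch of the API change (the argument names are taken from the diffs below; the toy module is illustrative):

import torch

from torchtnt.utils.swa import AveragedModel

model = torch.nn.Linear(4, 4)

# Before this change, callers passed an averaging callable directly, e.g.
# multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999) or a custom avg_fn.
# After this change, the averaging behavior is selected by name:
ema_model = AveragedModel(model, averaging_method="ema", ema_decay=0.999)
swa_model = AveragedModel(model, averaging_method="swa")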
10 changes: 3 additions & 7 deletions tests/framework/test_auto_unit.py
@@ -281,8 +281,7 @@ def test_stochastic_weight_averaging_update_freq(self) -> None:
swalr_params=SWALRParams(
anneal_steps_or_epochs=5,
),
# pyre-ignore: Undefined attribute [16]: Module
multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(),
averaging_method="ema",
)
auto_unit = DummyAutoUnit(
module=my_module,
@@ -365,8 +364,7 @@ def forward(self, x):
swalr_params=SWALRParams(
anneal_steps_or_epochs=3,
),
# pyre-ignore: Undefined attribute [16]
multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(),
averaging_method="ema",
),
)

@@ -381,8 +379,7 @@ def forward(self, x):
swalr_params=SWALRParams(
anneal_steps_or_epochs=3,
),
# pyre-ignore: Undefined attribute [16]
multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(),
averaging_method="ema",
),
)

@@ -619,7 +616,6 @@ def forward(self, x):
swalr_params=SWALRParams(
anneal_steps_or_epochs=3,
),
avg_fn=lambda x, y, z: x,
)

auto_unit = DummyAutoUnit(
70 changes: 32 additions & 38 deletions tests/utils/test_swa.py
@@ -11,10 +11,6 @@

import torch

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore: Undefined import [21]
from torch.optim.swa_utils import get_ema_multi_avg_fn, get_swa_multi_avg_fn

from torchtnt.utils.swa import AveragedModel


@@ -47,18 +43,15 @@ def _run_averaged_steps(
self, dnn: torch.nn.Module, swa_device: torch.device, ema: bool
) -> Tuple[List[torch.Tensor], torch.nn.Module]:
ema_decay = 0.999
multi_avg_fn = (
# pyre-ignore: Undefined attribute [16]
get_ema_multi_avg_fn(ema_decay)
if ema
# pyre-ignore: Undefined attribute [16]
else get_swa_multi_avg_fn()
)
averaged_dnn = AveragedModel(
dnn,
device=swa_device,
multi_avg_fn=multi_avg_fn,
)
if ema:
averaged_dnn = AveragedModel(
dnn,
device=swa_device,
averaging_method="ema",
ema_decay=ema_decay,
)
else:
averaged_dnn = AveragedModel(dnn, device=swa_device, averaging_method="swa")

averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]

@@ -105,14 +98,12 @@ def test_averaged_model_state_dict(self) -> None:
self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

def test_averaged_model_exponential(self) -> None:
combos = itertools.product([True, False], [True, False], [True, False])
for use_multi_avg_fn, use_buffers, skip_deepcopy in combos:
self._test_averaged_model_exponential(
use_multi_avg_fn, use_buffers, skip_deepcopy
)
combos = itertools.product([True, False], [True, False])
for use_buffers, skip_deepcopy in combos:
self._test_averaged_model_exponential(use_buffers, skip_deepcopy)

def _test_averaged_model_exponential(
self, use_multi_avg_fn: bool, use_buffers: bool, skip_deepcopy: bool
self, use_buffers: bool, skip_deepcopy: bool
) -> None:
# Test AveragedModel with EMA as avg_fn and use_buffers as True.
dnn = torch.nn.Sequential(
@@ -122,22 +113,13 @@ def _test_averaged_model_exponential(
)
decay: float = 0.9

if use_multi_avg_fn:
averaged_dnn = AveragedModel(
deepcopy(dnn) if skip_deepcopy else dnn,
# pyre-ignore Undefined attribute [16]
multi_avg_fn=get_ema_multi_avg_fn(decay),
use_buffers=use_buffers,
skip_deepcopy=skip_deepcopy,
)
else:

def avg_fn(
p_avg: torch.Tensor, p: torch.Tensor, n_avg: float
) -> torch.Tensor:
return decay * p_avg + (1 - decay) * p

averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)
averaged_dnn = AveragedModel(
deepcopy(dnn) if skip_deepcopy else dnn,
averaging_method="ema",
ema_decay=decay,
use_buffers=use_buffers,
skip_deepcopy=skip_deepcopy,
)

if use_buffers:
dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
@@ -199,3 +181,15 @@ def test_averaged_model_skip_deepcopy(self) -> None:

averaged_dnn2 = AveragedModel(dnn, device)
self.assertNotEqual(id(dnn), id(averaged_dnn2.module))

def test_input_checks(self) -> None:
model = torch.nn.Linear(2, 2)

with self.assertRaisesRegex(ValueError, "Decay must be between 0 and 1"):
AveragedModel(model, averaging_method="ema", ema_decay=1.3)

with self.assertRaisesRegex(
ValueError, "Unknown averaging method: foo. Only ema and swa are supported."
):
# pyre-ignore On purpose to test run time exception
AveragedModel(model, averaging_method="foo")
13 changes: 7 additions & 6 deletions torchtnt/framework/auto_unit.py
@@ -84,8 +84,9 @@ class SWAParams:
Args:
warmup_steps_or_epochs: number of steps or epochs before starting SWA
step_or_epoch_update_freq: number of steps or epochs between each SWA update
avg_fn: function to compute custom average of parameters
multi_avg_fn: function used to update parameters inplace
averaging_method: whether to use SWA or EMA to average model weights
ema_decay: The exponential decay applied to the averaged parameters. This param
is only needed for EMA, and is ignored otherwise (for SWA).
swalr_params: params for SWA learning rate scheduler
Note: Whether steps or epochs is used based on what `step_lr_interval` is set on the AutoUnit.
@@ -96,8 +97,8 @@

warmup_steps_or_epochs: int
step_or_epoch_update_freq: int
avg_fn: Optional[TSWA_avg_fn] = None
multi_avg_fn: Optional[TSWA_multi_avg_fn] = None
averaging_method: Literal["ema", "swa"] = "ema"
ema_decay: float = 0.999
swalr_params: Optional[SWALRParams] = None


@@ -476,9 +477,9 @@ def __init__(
self.swa_model = AveragedModel(
module_for_swa,
device=device,
avg_fn=swa_params.avg_fn,
multi_avg_fn=swa_params.multi_avg_fn,
use_buffers=True,
averaging_method=swa_params.averaging_method,
ema_decay=swa_params.ema_decay,
skip_deepcopy=skip_deepcopy,
)

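A hedged sketch of how the new SWAParams fields are filled in (the warmup and update-frequency values are illustrative, and handing the params to an AutoUnit through a swa_params argument is an assumption based on the tests above):

from torchtnt.framework.auto_unit import SWALRParams, SWAParams

swa_params = SWAParams(
    warmup_steps_or_epochs=100,
    step_or_epoch_update_freq=10,
    averaging_method="ema",  # or "swa"
    ema_decay=0.999,  # only used when averaging_method="ema"
    swalr_params=SWALRParams(anneal_steps_or_epochs=5),
)
# These params would then be passed to an AutoUnit subclass (e.g. as its
# swa_params argument), replacing the old avg_fn / multi_avg_fn callables.
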
46 changes: 34 additions & 12 deletions torchtnt/utils/swa.py
@@ -4,12 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional
from typing import Callable, List, Literal, Optional

import torch

from torch.optim.swa_utils import AveragedModel as PyTorchAveragedModel

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined import [21]: Could not find a name `get_ema_multi_avg_fn` defined in module `torch.optim.swa_utils`.
from torch.optim.swa_utils import (
AveragedModel as PyTorchAveragedModel,
get_ema_multi_avg_fn,
get_swa_multi_avg_fn,
)

TSWA_avg_fn = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
TSWA_multi_avg_fn = Callable[[List[torch.Tensor], List[torch.Tensor], int], None]
@@ -20,34 +25,52 @@ def __init__(
self,
model: torch.nn.Module,
device: Optional[torch.device] = None,
avg_fn: Optional[TSWA_avg_fn] = None,
multi_avg_fn: Optional[TSWA_multi_avg_fn] = None,
use_buffers: bool = False,
averaging_method: Literal["ema", "swa"] = "ema",
ema_decay: float = 0.999,
skip_deepcopy: bool = False,
) -> None:
"""
This class is a custom version of AveragedModel that allows us to skip the
automatic deepcopy step. This gives flexibility to use modules that are not
automatic deepcopy step and streamline the use of EMA / SWA. The deepcopy
optionality gives flexibility to use modules that are not
compatible with deepcopy, like FSDP wrapped modules. Check out
https://github.com/pytorch/pytorch/blob/main/torch/optim/swa_utils.py#L66
to see what the arguments entail.
to see what the model, device, and use_buffer arguments entail.
Args:
averaging_method: Whether to use EMA or SWA.
ema_decay: The exponential decay applied to the averaged parameters. This param
is only needed for EMA, and is ignored otherwise (for SWA).
skip_deepcopy: If True, will skip the deepcopy step. The user must ensure
that the module passed in is already copied in some way
"""
# setup averaging method
if averaging_method == "ema":
if ema_decay < 0.0 or ema_decay > 1.0:
raise ValueError(f"Decay must be between 0 and 1, got {ema_decay}")

# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_ema_multi_avg_fn`.
multi_avg_fn = get_ema_multi_avg_fn(ema_decay)
elif averaging_method == "swa":
# TODO: torch/optim/swa_utils.pyi needs to be updated
# pyre-ignore Undefined attribute [16]: Module `torch.optim.swa_utils` has no attribute `get_swa_multi_avg_fn`.
multi_avg_fn = get_swa_multi_avg_fn()
else:
raise ValueError(
f"Unknown averaging method: {averaging_method}. Only ema and swa are supported."
)

if skip_deepcopy:
# calls parent init manually, but skips deepcopy step
torch.nn.Module.__init__(self) # inits grandparent class

assert (
avg_fn is None or multi_avg_fn is None
), "Only one of avg_fn and multi_avg_fn should be provided"
self.module: torch.nn.Module = model
self.register_buffer(
"n_averaged", torch.tensor(0, dtype=torch.long, device=device)
)
self.avg_fn: Optional[TSWA_avg_fn] = avg_fn
self.avg_fn: Optional[TSWA_avg_fn] = None
self.multi_avg_fn: Optional[TSWA_multi_avg_fn] = multi_avg_fn
self.use_buffers: bool = use_buffers
else:
@@ -58,7 +81,6 @@ def __init__(
super().__init__(
model,
device=device,
avg_fn=avg_fn,
multi_avg_fn=multi_avg_fn,
use_buffers=use_buffers,
)
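
For reference, a small sketch of what the two averaging methods now map to under the hood (update rules as provided by torch.optim.swa_utils; the model and the parameter perturbation are illustrative):

import torch

from torchtnt.utils.swa import AveragedModel

model = torch.nn.Linear(2, 2)
ema_model = AveragedModel(model, averaging_method="ema", ema_decay=0.9)
swa_model = AveragedModel(model, averaging_method="swa")

for _ in range(3):
    with torch.no_grad():
        model.weight.add_(0.1)  # stand-in for an optimizer step
    # After the first update (which copies the parameters), each call applies:
    #   EMA (get_ema_multi_avg_fn): p_avg <- decay * p_avg + (1 - decay) * p
    #   SWA (get_swa_multi_avg_fn): p_avg <- p_avg + (p - p_avg) / (n_averaged + 1)
    ema_model.update_parameters(model)
    swa_model.update_parameters(model)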
