From 99cc449cdd4bb7d5a2dc0145f786c36f0e0eb087 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Mon, 24 Feb 2025 16:13:41 +0000
Subject: [PATCH] Generalize the Optimizers container type by passing the base internal optimizer class.

Pass `optimizer_cls` to the `OptimizersContainer` and
`OptimizersInBackwardContainer` constructors instead of `name`.
---
 tests/unit_tests/test_train_spec.py |  2 +-
 torchtitan/components/optimizer.py  | 44 ++++++++++++++++-------------
 2 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py
index 678684258e..dadd67de54 100644
--- a/tests/unit_tests/test_train_spec.py
+++ b/tests/unit_tests/test_train_spec.py
@@ -47,8 +47,8 @@ def fake_build_optimizers(
     }
     return OptimizersContainer(
         model_parts=model_parts,
+        optimizer_cls=torch.optim.Adam,
         optimizer_kwargs=optimizer_kwargs,
-        name="Adam",
     )

diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 277d1b75fe..2a4b52f0e4 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -6,7 +6,7 @@
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Generic, List, TypeVar

 import torch
 import torch.nn as nn
@@ -30,18 +30,10 @@
 ]


-def _create_optimizer(
-    parameters: Iterable[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str
-) -> Optimizer:
-    if name == "Adam":
-        return torch.optim.Adam(parameters, **optimizer_kwargs)
-    elif name == "AdamW":
-        return torch.optim.AdamW(parameters, **optimizer_kwargs)
-    else:
-        raise NotImplementedError(f"Optimizer {name} not added.")
+T = TypeVar("T", bound=Optimizer)


-class OptimizersContainer(Optimizer):
+class OptimizersContainer(Optimizer, Generic[T]):
     """A container for multiple optimizers.

     This class is used to wrap multiple optimizers into a single object that can be
@@ -67,18 +59,21 @@ class OptimizersContainer(Optimizer):
         name (str): Name of the optimizers.
""" - optimizers: List[Optimizer] + optimizers: List[T] model_parts: List[nn.Module] def __init__( - self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + self, + model_parts: List[nn.Module], + optimizer_cls: type[T], + optimizer_kwargs: Dict[str, Any], ) -> None: all_params = [] - self.optimizers: List[Optimizer] = [] + self.optimizers: List[T] = [] self.model_parts = model_parts for model in self.model_parts: params = [p for p in model.parameters() if p.requires_grad] - self.optimizers.append(_create_optimizer(params, optimizer_kwargs, name)) + self.optimizers.append(optimizer_cls(params, **optimizer_kwargs)) all_params.extend(params) self._validate_length(len(self.model_parts)) self._post_init(all_params, optimizer_kwargs) @@ -139,7 +134,10 @@ class OptimizersInBackwardContainer(OptimizersContainer): """ def __init__( - self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + self, + model_parts: List[nn.Module], + optimizer_cls: type[T], + optimizer_kwargs: Dict[str, Any], ) -> None: all_params = [] self.model_parts = model_parts @@ -148,7 +146,7 @@ def __init__( for model in self.model_parts: for p in model.parameters(): if p.requires_grad: - optim_dict[p] = _create_optimizer([p], optimizer_kwargs, name) + optim_dict[p] = optimizer_cls([p], **optimizer_kwargs) all_params.append(p) def optim_hook(param) -> None: @@ -218,11 +216,17 @@ def build_optimizers( "fused": fused, "foreach": foreach, } - + optimizer_classes = { + "Adam": torch.optim.Adam, + "AdamW": torch.optim.AdamW, + } + if name not in optimizer_classes: + raise NotImplementedError(f"Optimizer {name} not added.") + optimizer_cls = optimizer_classes[name] return ( - OptimizersContainer(model_parts, optimizer_kwargs, name) + OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) if not optim_in_bwd - else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name) + else OptimizersInBackwardContainer(model_parts, optimizer_cls, optimizer_kwargs) )