Merged
2 changes: 1 addition & 1 deletion tests/unit_tests/test_train_spec.py
@@ -47,8 +47,8 @@ def fake_build_optimizers(
     }
     return OptimizersContainer(
         model_parts=model_parts,
+        optimizer_cls=torch.optim.Adam,
         optimizer_kwargs=optimizer_kwargs,
-        name="Adam",
     )


44 changes: 24 additions & 20 deletions torchtitan/components/optimizer.py
@@ -6,7 +6,7 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Generic, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -30,18 +30,10 @@
 ]
 
 
-def _create_optimizer(
-    parameters: Iterable[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str
-) -> Optimizer:
-    if name == "Adam":
-        return torch.optim.Adam(parameters, **optimizer_kwargs)
-    elif name == "AdamW":
-        return torch.optim.AdamW(parameters, **optimizer_kwargs)
-    else:
-        raise NotImplementedError(f"Optimizer {name} not added.")
+T = TypeVar("T", bound=Optimizer)
 
 
-class OptimizersContainer(Optimizer):
+class OptimizersContainer(Optimizer, Generic[T]):
     """A container for multiple optimizers.
 
     This class is used to wrap multiple optimizers into a single object that can be
@@ -67,18 +59,21 @@ class OptimizersContainer(Optimizer):
         name (str): Name of the optimizers.
     """
 
-    optimizers: List[Optimizer]
+    optimizers: List[T]
     model_parts: List[nn.Module]
 
     def __init__(
-        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
     ) -> None:
         all_params = []
-        self.optimizers: List[Optimizer] = []
+        self.optimizers: List[T] = []
         self.model_parts = model_parts
         for model in self.model_parts:
             params = [p for p in model.parameters() if p.requires_grad]
-            self.optimizers.append(_create_optimizer(params, optimizer_kwargs, name))
+            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
             all_params.extend(params)
         self._validate_length(len(self.model_parts))
         self._post_init(all_params, optimizer_kwargs)
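
With `OptimizersContainer` now generic over the optimizer class, callers pass the optimizer class itself instead of a string name, and the container's `optimizers` list is typed accordingly. A minimal sketch of the new call pattern, assuming `OptimizersContainer` is imported from `torchtitan.components.optimizer`; the model parts and hyperparameters below are made up for illustration:

```python
import torch
import torch.nn as nn

from torchtitan.components.optimizer import OptimizersContainer

# Illustrative stand-ins for the real model shards torchtitan would pass in.
model_parts = [nn.Linear(16, 16), nn.Linear(16, 16)]

# New signature: the optimizer class and its kwargs, instead of a string name.
optimizers = OptimizersContainer(
    model_parts=model_parts,
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs={"lr": 3e-4, "weight_decay": 0.1},
)

# To a type checker, optimizers.optimizers is now List[torch.optim.AdamW]
# rather than a plain List[Optimizer].
```
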
@@ -139,7 +134,10 @@ class OptimizersInBackwardContainer(OptimizersContainer):
     """
 
     def __init__(
-        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
    ) -> None:
         all_params = []
         self.model_parts = model_parts
@@ -148,7 +146,7 @@ def __init__(
         for model in self.model_parts:
             for p in model.parameters():
                 if p.requires_grad:
-                    optim_dict[p] = _create_optimizer([p], optimizer_kwargs, name)
+                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
                     all_params.append(p)
 
         def optim_hook(param) -> None:
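
`OptimizersInBackwardContainer` builds one optimizer per parameter so that each parameter can be stepped as soon as its gradient finishes accumulating during backward, which is what the `optim_hook` defined below this hunk (collapsed in the diff view) is registered for. A self-contained sketch of that general pattern, not the torchtitan implementation itself; the module, learning rate, and optimizer choice are illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)

# One optimizer per parameter, keyed by the parameter tensor.
optim_dict = {
    p: torch.optim.AdamW([p], lr=3e-4) for p in model.parameters() if p.requires_grad
}

def optim_hook(param: torch.Tensor) -> None:
    # Step and clear this parameter's gradient as soon as it is ready.
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        # Fires once the gradient for `p` has been fully accumulated in backward.
        p.register_post_accumulate_grad_hook(optim_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # parameters are updated during backward; no separate optimizer.step()
```
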
@@ -218,11 +216,17 @@ def build_optimizers(
         "fused": fused,
         "foreach": foreach,
     }
 
+    optimizer_classes = {
+        "Adam": torch.optim.Adam,
+        "AdamW": torch.optim.AdamW,
+    }
+    if name not in optimizer_classes:
+        raise NotImplementedError(f"Optimizer {name} not added.")
+    optimizer_cls = optimizer_classes[name]
     return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
+        OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
         if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+        else OptimizersInBackwardContainer(model_parts, optimizer_cls, optimizer_kwargs)
     )
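
The `optimizer_classes` dict replaces the if/elif dispatch that lived in the deleted `_create_optimizer` helper, so the name-to-class lookup now happens once in `build_optimizers` and the containers themselves stay agnostic of optimizer names. A hypothetical sketch of how another optimizer could be registered under this scheme; this is not part of the PR, shown only to illustrate the extension point:

```python
import torch

# Hypothetical extension of the lookup table in build_optimizers:
# any torch.optim.Optimizer subclass can be mapped to a config name.
optimizer_classes = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,  # illustrative addition, not in this PR
}
```
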

