turn on foreach
wz337 authored and pytorchmergebot committed May 10, 2024
1 parent d7fe3c4 commit 240236d
Showing 4 changed files with 87 additions and 50 deletions.
121 changes: 73 additions & 48 deletions test/distributed/_tensor/test_optimizers.py
@@ -84,23 +84,26 @@ def _assert_optimizer(
# Default 'rtol' and 'atol' for attr:`~torch.float32` are ``1.3e-6`` and ``1e-5``
self.assertEqual(p1, p2, atol=atol, rtol=rtol)

def test_optimizer_foreach_supported_types_include_DTensor(self):
from torch.optim.optimizer import _foreach_supported_types

self.assertTrue(DTensor in _foreach_supported_types)

@with_comms
def test_adam_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

# TODO: add fused_adam support
adam_configs = [
{"lr": 0.1},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
{"lr": 0.1, "foreach": True},
{"lr": 0.1, "weight_decay": 0.05, "foreach": True},
{"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "foreach": True},
{"lr": 0.1, "weight_decay": 0.05, "amsgrad": True},
{
"lr": 0.1,
"weight_decay": 0.05,
"maximize": True,
"amsgrad": True,
"foreach": True,
},
{"lr": 0.1, "fused": True},
{"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "fused": True},
@@ -132,16 +135,15 @@ def test_adamw_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

adamw_configs = [
{"lr": 0.1},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
{"lr": 0.1, "weight_decay": 0.05, "foreach": True},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"amsgrad": True,
"foreach": True,
},
{
"lr": 0.1,
@@ -150,7 +152,6 @@ def test_adamw_1d_sharding(self):
"weight_decay": 0.05,
"maximize": True,
"amsgrad": True,
"foreach": True,
},
{"lr": 0.1, "weight_decay": 0.05, "fused": True},
{
@@ -191,24 +192,24 @@ def test_sgd_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

sgd_configs = [
{"lr": 0.1},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "momentum": 0.05, "foreach": False},
{"lr": 0.1, "momentum": 0.05},
{"lr": 0.1, "momentum": 0.05, "foreach": True},
{"lr": 0.1, "momentum": 0.06, "dampening": 0.07, "foreach": True},
{"lr": 0.1, "momentum": 0.06, "dampening": 0.07},
{
"lr": 0.1,
"momentum": 0.08,
"weight_decay": 0.05,
"nesterov": True,
"maximize": True,
"foreach": False,
},
{
"lr": 0.1,
"momentum": 0.08,
"weight_decay": 0.05,
"nesterov": True,
"maximize": True,
"foreach": True,
},
]

@@ -231,21 +232,23 @@ def test_adagrad_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

adagrad_configs = [
{"lr": 0.1},
{"lr": 0.1, "lr_decay": 0.05},
{"lr": 0.1, "lr_decay": 0.02, "weight_decay": 0.05},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "lr_decay": 0.05, "foreach": False},
{"lr": 0.1, "lr_decay": 0.02, "weight_decay": 0.05, "foreach": False},
{
"lr": 0.1,
"lr_decay": 0.02,
"weight_decay": 0.05,
"initial_accumulator_value": 0.03,
"foreach": False,
},
{
"lr": 0.1,
"lr_decay": 0.02,
"weight_decay": 0.05,
"initial_accumulator_value": 0.03,
"eps": 1e-6,
"foreach": False,
},
{
"lr": 0.1,
@@ -254,6 +257,7 @@ def test_adagrad_1d_sharding(self):
"initial_accumulator_value": 0.03,
"eps": 1e-6,
"maximize": True,
"foreach": False,
},
{
"lr": 0.1,
@@ -262,7 +266,6 @@ def test_adagrad_1d_sharding(self):
"initial_accumulator_value": 0.03,
"eps": 1e-6,
"maximize": True,
"foreach": True,
},
]

@@ -285,16 +288,23 @@ def test_RMSprop_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

RMSprop_configs = [
{"lr": 0.1},
{"lr": 0.1, "alpha": 0.85},
{"lr": 0.1, "alpha": 0.88, "eps": 1e-6},
{"lr": 0.1, "alpha": 0.88, "eps": 1e-6, "weight_decay": 0.05},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "alpha": 0.85, "foreach": False},
{"lr": 0.1, "alpha": 0.88, "eps": 1e-6, "foreach": False},
{
"lr": 0.1,
"alpha": 0.88,
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": False,
},
{
"lr": 0.1,
"alpha": 0.88,
"eps": 1e-6,
"weight_decay": 0.05,
"momentum": 0.9,
"foreach": False,
},
{
"lr": 0.1,
@@ -303,6 +313,7 @@ def test_RMSprop_1d_sharding(self):
"weight_decay": 0.05,
"momentum": 0.9,
"centered": True,
"foreach": False,
},
{
"lr": 0.1,
@@ -312,6 +323,7 @@ def test_RMSprop_1d_sharding(self):
"momentum": 0.9,
"centered": True,
"maximize": True,
"foreach": False,
},
{
"lr": 0.1,
@@ -321,7 +333,6 @@ def test_RMSprop_1d_sharding(self):
"momentum": 0.9,
"centered": True,
"maximize": True,
"foreach": True,
},
]

@@ -344,23 +355,27 @@ def test_adadelta_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

adadelta_configs = [
{"lr": 0.1},
{"lr": 0.1, "rho": 0.85},
{"lr": 0.1, "rho": 0.88, "eps": 1e-5},
{"lr": 0.1, "rho": 0.88, "eps": 1e-6, "weight_decay": 0.05},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "rho": 0.85, "foreach": False},
{"lr": 0.1, "rho": 0.88, "eps": 1e-5, "foreach": False},
{
"lr": 0.1,
"rho": 0.88,
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": False,
},
{
"lr": 0.1,
"rho": 0.88,
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": True,
},
{
"lr": 0.1,
"rho": 0.88,
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": True,
"maximize": True,
},
]
@@ -384,23 +399,21 @@ def test_nadam_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

nadam_configs = [
{"lr": 0.1},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
{"lr": 0.1, "weight_decay": 0.05, "foreach": True},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": True,
},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"decoupled_weight_decay": True,
"foreach": True,
},
]

@@ -423,23 +436,24 @@ def test_radam_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

radam_configs = [
{"lr": 0.1},
{"lr": 0.1, "weight_decay": 0.05},
{"lr": 0.1, "weight_decay": 0.05, "foreach": True},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{
"lr": 0.1,
"weight_decay": 0.05,
},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": True,
},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"decoupled_weight_decay": True,
"foreach": True,
},
]

@@ -462,23 +476,27 @@ def test_adamax_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

adamax_configs = [
{"lr": 0.1},
{"lr": 0.1, "betas": (0.6, 0.66)},
{"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6},
{"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6, "weight_decay": 0.05},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "betas": (0.6, 0.66), "foreach": False},
{"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6, "foreach": False},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": False,
},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": True,
},
{
"lr": 0.1,
"betas": (0.6, 0.66),
"eps": 1e-6,
"weight_decay": 0.05,
"foreach": True,
"maximize": True,
},
]
@@ -502,11 +520,18 @@ def test_asgd_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

asgd_configs = [
{"lr": 0.1},
{"lr": 0.1, "lambd": 0.001},
{"lr": 0.1, "lambd": 0.001, "alpha": 0.85},
{"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5},
{"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5, "weight_decay": 0.05},
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "lambd": 0.001, "foreach": False},
{"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "foreach": False},
{"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5, "foreach": False},
{
"lr": 0.1,
"lambd": 0.001,
"alpha": 0.85,
"t0": 1e5,
"weight_decay": 0.05,
"foreach": False,
},
{
"lr": 0.1,
"lambd": 0.001,
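With DTensor registered in _foreach_supported_types, the configs above that omit the "foreach" flag now exercise the multi-tensor path by default, and "foreach": False is what pins a case to the single-tensor loop. The sketch below is illustrative only — the real selection inside torch.optim also considers device placement and differentiability — and resolves_to_foreach is a hypothetical helper, not a PyTorch API.

from typing import Iterable, Optional

import torch
from torch.distributed._tensor import DTensor  # importing this module registers DTensor
from torch.optim.optimizer import _foreach_supported_types


def resolves_to_foreach(params: Iterable[torch.Tensor], foreach: Optional[bool]) -> bool:
    # An explicit flag wins; otherwise prefer foreach when every parameter type
    # is registered as foreach-capable.
    if foreach is not None:
        return foreach
    return all(type(p) in _foreach_supported_types for p in params)


assert DTensor in _foreach_supported_types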
7 changes: 5 additions & 2 deletions test/distributed/tensor/parallel/test_tp_examples.py
@@ -262,8 +262,11 @@ def test_transformer_training(self, is_seq_parallel=False):

# Ensure model weights are still the same after update.
optim.step()
with CommDebugMode() as comm_mode:
optim_tp.step()
from torch.distributed._tensor.experimental import implicit_replication

with implicit_replication():
with CommDebugMode() as comm_mode:
optim_tp.step()
self._check_module(model, model_tp)
if is_seq_parallel:
self.assertDictEqual(
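The change above wraps the tensor-parallel optimizer step in implicit_replication, which lets plain torch.Tensor operands that show up during a foreach/fused step (for example scalar state) participate in DTensor ops as if they were replicated, instead of raising. A minimal sketch of the same pattern, assuming a parallelized model_tp with DTensor parameters and an input batch inp already exist:

import torch
from torch.distributed._tensor.experimental import implicit_replication

optim_tp = torch.optim.Adam(model_tp.parameters(), lr=0.25, foreach=True)

model_tp(inp).sum().backward()
with implicit_replication():
    # Mixed Tensor/DTensor operations inside the step are treated as replicated,
    # so the step succeeds and tools like CommDebugMode can record its collectives.
    optim_tp.step()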
8 changes: 8 additions & 0 deletions torch/distributed/_tensor/__init__.py
@@ -10,6 +10,8 @@
from torch.distributed._tensor.ops.utils import normalize_to_torch_size
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.optim.optimizer import _foreach_supported_types


# All public APIs from dtensor package
__all__ = [
@@ -23,6 +25,12 @@
]


# Append DTensor to the list of supported types for foreach implementation of optimizer
# so that we will try to use foreach over the for-loop implementation on CUDA.
if DTensor not in _foreach_supported_types:
_foreach_supported_types.append(DTensor)


def _dtensor_init_helper(
init_op,
size: torch.Size,
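Functionally, the registration above is what routes DTensor parameters to the horizontally fused torch._foreach_* kernels rather than a per-parameter Python loop. A standalone illustration on plain local tensors (the values are arbitrary):

import torch

params = [torch.ones(4) for _ in range(3)]
grads = [torch.full((4,), 0.5) for _ in range(3)]
# One fused call updates every tensor in the list: p -= 0.1 * g
torch._foreach_add_(params, grads, alpha=-0.1)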
1 change: 1 addition & 0 deletions torch/distributed/_tensor/dispatch.py
@@ -302,6 +302,7 @@ def unwrap_to_op_info(
args_schema.append(arg._spec)
local_args.append(arg._local_tensor)
if mesh is not None:
print(f"{mesh=}, {arg.device_mesh=}")
if mesh != arg.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
