Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport Enable eager spmd #7341 and fix eager mode spmd module loading with fsdpv2 #7631 #7673

Merged
merged 2 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions examples/eager/train_decoder_only_eager_spmd_data_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import numpy as np

import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
from torch_xla import runtime as xr

# Enable the SPMD
xr.use_spmd()


# A more detailed example can be found at https://github.com/pytorch/xla/blob/master/test/spmd/test_train_spmd_imagenet.py
# See also our user guide at https://github.com/pytorch/xla/blob/master/docs/spmd.md
class TrainDecoderSpmdDDP(TrainDecoderOnlyBase):
    """Decoder-only training example with SPMD data-parallel input sharding.

    After the base-class setup, the constructor multiplies ``self.batch_size``
    by the number of runtime devices and replaces ``self.train_device_loader``
    with a loader whose inputs are sharded along the batch dimension.

    Reads ``self.batch_size``, ``self.seq_len``, ``self.train_dataset_len`` and
    ``self.device`` from the instance (presumably populated by
    ``TrainDecoderOnlyBase.__init__`` -- confirm against the base class).
    """

    def __init__(self):
        super().__init__()
        device_count = xr.global_runtime_device_count()

        # A single process handles all runtime devices, so the batch size is
        # scaled up by the device count.
        self.batch_size *= device_count

        # One-dimensional mesh over every device with a single axis, 'data'.
        data_mesh = xs.Mesh(
            np.arange(device_count), (device_count,), ('data',))
        # Shard the input's batch dimension along the `data` axis; the other
        # dimension is left unsharded.
        batch_sharding = xs.ShardingSpec(data_mesh, ('data', None))

        # Synthetic all-zero int64 batch: a pair of
        # (batch_size, seq_len) tensors.
        batch_dims = (self.batch_size, self.seq_len)
        synthetic_batch = (
            torch.zeros(*batch_dims, dtype=torch.int64),
            torch.zeros(*batch_dims, dtype=torch.int64),
        )
        sample_source = xu.SampleGenerator(
            data=synthetic_batch,
            sample_count=self.train_dataset_len // self.batch_size)

        self.train_device_loader = pl.MpDeviceLoader(
            sample_source, self.device, input_sharding=batch_sharding)


if __name__ == '__main__':
    # Eager mode is switched on before the trainer (and its model) is built.
    torch_xla.experimental.eager_mode(True)
    trainer = TrainDecoderSpmdDDP()
    trainer.start_training()
65 changes: 65 additions & 0 deletions test/eager/test_eager_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
import sys

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import numpy as np


class MultiLinear(torch.nn.Module):
    """Three stacked linear layers mapping 10 -> 20 -> 30 -> 40 features."""

    def __init__(self):
        super().__init__()
        # Creation order is kept (linear1, linear2, linear3) so that seeded
        # parameter initialisation stays the same.
        self.linear1 = torch.nn.Linear(10, 20)
        self.linear2 = torch.nn.Linear(20, 30)
        self.linear3 = torch.nn.Linear(30, 40)

    def forward(self, input):
        """Apply linear1, linear2 and linear3 in sequence to ``input``."""
        hidden = self.linear1(input)
        hidden = self.linear2(hidden)
        return self.linear3(hidden)


class Eager(unittest.TestCase):
    """SPMD sharding tests that run with torch_xla in eager mode."""

    @classmethod
    def setUpClass(cls):
        # Enable eager execution and SPMD once for the whole test case, then
        # cache the runtime device count and a matching array of device ids.
        torch_xla.experimental.eager_mode(True)
        xr.use_spmd()
        device_count = xr.global_runtime_device_count()
        cls.n_devices = device_count
        cls.device_ids = np.array(range(device_count))

    def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
        """Build an ``xs.Mesh`` of shape ``mesh_shape``.

        ``device_ids`` defaults to the ids cached by ``setUpClass``; either
        way its length must equal the cached device count. ``axis_names`` is
        forwarded to ``xs.Mesh`` unchanged.
        """
        assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'
        ids = self.device_ids if device_ids is None else device_ids
        assert len(ids) == self.n_devices
        return xs.Mesh(ids, mesh_shape, axis_names)

    def test_eager_spmd_basic(self):
        """A linear layer gives matching results on CPU and on a sharded XLA input."""
        device = torch_xla.device()
        mesh = self._get_mesh((self.n_devices,), axis_names=('data',))
        torch.manual_seed(100)
        layer = torch.nn.Linear(10, 20)
        cpu_input = torch.randn(8, 10)
        xla_input = cpu_input.to(device)
        # Shard the batch dimension along 'data'; leave the feature dimension
        # unsharded.
        xs.mark_sharding(xla_input, mesh, ('data', None))
        # Take the CPU reference result before the layer is moved to the device.
        expected = layer(cpu_input)
        layer.to(device)
        actual = layer(xla_input).cpu()
        self.assertTrue(torch.allclose(expected, actual, atol=1e-2))

    def test_module_to_empty_sharding(self):
        """Moving a module to the device leaves its weight's sharding spec empty."""
        device = torch_xla.device()
        model = MultiLinear()
        model.to(device)
        weight = model.linear1.weight
        # NOTE(review): this result is discarded; the call looks redundant with
        # the assertion below -- confirm before removing.
        torch_xla._XLAC._get_xla_sharding_spec(weight)
        self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(weight), '')


if __name__ == '__main__':
    # unittest.main() calls sys.exit() itself by default, which would make the
    # explicit exit below unreachable. exit=False returns the TestProgram so
    # the process exit status is derived from the collected result here.
    test = unittest.main(exit=False)
    sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
run_test "$CDIR/eager/test_eager_with_torch_compile.py"
run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
run_test "$CDIR/eager/test_eager_spmd.py"
}

# All the new xla op tests should go to run_xla_op_tests3
Expand Down
3 changes: 3 additions & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py
python3 test/test_gmm.py
python3 test/eager/test_eager_spmd.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
Expand All @@ -39,9 +40,11 @@ python3 examples/train_resnet_amp.py

# HACK: don't confuse local `torch_xla` folder with installed package
# Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
# Eager tests take more HBM, so only run them on TPU v4 CI
TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())")
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
python3 examples/eager/train_decoder_only_eager.py
python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py
python3 examples/eager/train_decoder_only_eager_with_compile.py
python3 examples/eager/train_decoder_only_eager_multi_process.py
fi
24 changes: 18 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1503,13 +1503,23 @@ at::Tensor XLANativeFunctions::empty_symint(
// does not actually end up doing any memory initialization, we use that and
// avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
// s_copy_().
XLATensorPtr xla_tensor;
if (all_dims_static) {
return bridge::AtenFromXlaTensor(tensor_methods::full(
XlaHelpers::I64List(int_sizes.value()), 0,
GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype));
} else {
xla_tensor =
tensor_methods::full_symint(sym_size, 0, GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype));
}
// `tensor.to` will trigger an `empty` + `_to_copy`. In eager mode, the
// `full` will be evaluated eagerly and get a replicated sharding. We should
// leave the sharding empty.
if (XLAGraphExecutor::Get()->UseEagerMode() && UseVirtualDevice()) {
xla_tensor->ClearShardingSpec();
}
return bridge::AtenFromXlaTensor(tensor_methods::full_symint(
sym_size, 0, GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
return bridge::AtenFromXlaTensor(xla_tensor);
}

at::Tensor XLANativeFunctions::empty_strided_symint(
Expand Down Expand Up @@ -2742,7 +2752,9 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,

// 2) Aid SPMD.
XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
if (sharding && sharding->sharding.type() != xla::OpSharding::UNKNOWN) {
// don't propagate sharding in eager mode.
if (!XLAGraphExecutor::Get()->UseEagerMode() && sharding &&
sharding->sharding.type() != xla::OpSharding::UNKNOWN) {
tensor_methods::custom_sharding_(output_tensor,
input_tensor->sharding_spec());
}
Expand Down
Loading