Skip to content

Commit e42fffa

Browse files
authored
[SPMD][Virtual Device]All tensors should be in SPMD:0 C++ device (#5284)
* Move all tensors to SPMD:0 C++ device under spmd context * fix load shards * fix test_mark_sharding_2d by not creating placeholder for virtual device * fix the waitdeviceop for spmd case * Fix test_shard_hashing * fix spmd device casting issue * remove hacks in test_xla_virtual_device.py * add test for new virtual device usage * fix review comments * fix IsTpuDevice * linter
1 parent 44033ed commit e42fffa

14 files changed

+150
-90
lines changed

test/spmd/test_xla_sharding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ def test_execute_replicated_metrics(self):
314314
xt = torch.ones(2, 2).to(xm.xla_device())
315315
xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
316316
xt += 2
317-
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt)
318317
xm.mark_step()
319318
xm.wait_device_ops()
320319
self.assertEqual(met.metric_data('ExecuteReplicatedTime')[0], 1)

test/spmd/test_xla_virtual_device.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ def test_outbound_data_metrics(self):
7979

8080
def test_non_tensor_scalar(self):
8181
sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
82-
# TODO(JackCaoG)currently, execution will only happen if there is at least one
83-
# tensor on non-spmd:0 device.
84-
t1 = torch.randn(3, 3, device=xm.xla_device())
8582
# tensor will have device as `SPMD:0` in c++
8683
xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
8784
xm.xla_device(),
@@ -95,9 +92,6 @@ def test_non_tensor_scalar(self):
9592
def test_mark_step_on_virtual_device(self):
9693
xm.mark_step()
9794
sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
98-
# TODO(JackCaoG)currently, execution will only happen if there is at least one
99-
# tensor on non-spmd:0 device.
100-
t1 = torch.randn(3, 3, device=xm.xla_device())
10195
# tensor will have device as `SPMD:0` in c++
10296
xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
10397
xm.xla_device(),
@@ -108,6 +102,63 @@ def test_mark_step_on_virtual_device(self):
108102
self.assertNotIn('aten::div',
109103
torch_xla._XLAC._get_xla_tensor_debug_info(xt2))
110104

105+
def test_virtual_device_no_upload(self):
106+
met.clear_all()
107+
device = xm.xla_device()
108+
t1 = torch.randn(5, 5).to(device)
109+
t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1)
110+
# t1's upload to device should be deferred
111+
self.assertIn("Tensor on host: with size [5, 5]", t1_debug_info)
112+
self.assertNotIn("TransferToServerTime", met.metric_names())
113+
# t1 should be on SPMD device under spmd context
114+
self.assertIn("Device: SPMD:0", t1_debug_info)
115+
self.assertIn("IR: None", t1_debug_info)
116+
self.assertIn("XLAData: None", t1_debug_info)
117+
118+
def test_virtual_device_upload_after_mark_sharding(self):
119+
met.clear_all()
120+
partition_spec = (0, 1)
121+
device = xm.xla_device()
122+
t1 = torch.randn(8, 8).to(device)
123+
t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1)
124+
self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info)
125+
xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), partition_spec)
126+
t1_debug_info_new = torch_xla._XLAC._get_xla_tensor_debug_info(t1)
127+
# tensor should be uploaded to device after mark_sharding
128+
self.assertIn("Tensor on host: None", t1_debug_info_new)
129+
self.assertIn("xla::device_data", t1_debug_info_new)
130+
self.assertIn("XLAShardedData", t1_debug_info_new)
131+
self.assertIn("TransferToServerTime", met.metric_names())
132+
133+
def test_virtual_device_upload_after_tracing(self):
134+
met.clear_all()
135+
device = xm.xla_device()
136+
t1 = torch.randn(8, 8).to(device)
137+
t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1)
138+
self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info)
139+
t2 = t1 + t1
140+
t1_debug_info_new = torch_xla._XLAC._get_xla_tensor_debug_info(t1)
141+
# tensor should be uploaded to device after being used as input to other op.
142+
self.assertIn("Tensor on host: None", t1_debug_info_new)
143+
self.assertIn("xla::device_data", t1_debug_info_new)
144+
self.assertIn("TransferToServerTime", met.metric_names())
145+
146+
def test_virtual_device_upload_for_sharded_dataloader(self):
147+
met.clear_counters()
148+
device = xm.xla_device()
149+
sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
150+
# tensor will have device as `SPMD:0` in c++
151+
t1 = xm.send_cpu_data_to_device([torch.randn(8, 8)],
152+
device,
153+
input_sharding=sharding_spec)[0]
154+
t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1)
155+
self.assertIn("Device: SPMD:0", t1_debug_info)
156+
# tensor should be uploaded to device after send_cpu_data_to_device + sharding_spec
157+
self.assertIn("Tensor on host: None", t1_debug_info)
158+
self.assertIn("xla::device_data", t1_debug_info)
159+
self.assertIn("XLAShardedData", t1_debug_info)
160+
self.assertIn("TransferToServerTime", met.metric_names())
161+
111162

112163
if __name__ == '__main__':
113164
test = unittest.main()

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@ class AtenXlaDeviceMapper {
3939

4040
private:
4141
AtenXlaDeviceMapper() {
42-
for (auto& device_str :
43-
torch_xla::runtime::GetComputationClient()->GetLocalDevices()) {
44-
devices_.emplace_back(ParseDeviceString(device_str));
45-
devices_ordinals_[devices_.back()] = devices_.size() - 1;
42+
if (UseVirtualDevice()) {
43+
devices_.emplace_back(ParseDeviceString("SPMD:0"));
44+
devices_ordinals_[devices_.back()] = 0;
45+
} else {
46+
for (auto& device_str :
47+
torch_xla::runtime::GetComputationClient()->GetLocalDevices()) {
48+
devices_.emplace_back(ParseDeviceString(device_str));
49+
devices_ordinals_[devices_.back()] = devices_.size() - 1;
50+
}
4651
}
4752
}
4853

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self,
467467
if (!self_tensor) {
468468
static bool sync_update =
469469
runtime::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true) &&
470-
!ShardingUtil::UseVirtualDevice();
470+
!UseVirtualDevice();
471471
XLA_CHECK(dst_tensor);
472472
dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
473473
} else if (!dst_tensor) {

torch_xla/csrc/device.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ std::string DeviceType::toString() const {
3737
torch::lazy::BackendDevice ParseDeviceString(const std::string& device_spec) {
3838
if (device_spec.empty()) {
3939
std::string default_device_spec =
40-
runtime::GetComputationClient()->GetDefaultDevice();
40+
UseVirtualDevice()
41+
? "SPMD:0"
42+
: runtime::GetComputationClient()->GetDefaultDevice();
4143
XLA_CHECK(!default_device_spec.empty());
4244
return ParseDeviceString(default_device_spec);
4345
}
@@ -101,4 +103,18 @@ torch::lazy::BackendDevice SetCurrentDevice(
101103
return current;
102104
}
103105

106+
bool ShouldUseVirtualDevice() {
107+
bool use_virtual_device =
108+
runtime::sys_util::GetEnvBool("XLA_USE_SPMD", false);
109+
if (use_virtual_device) {
110+
TF_LOG(INFO) << "Using SPMD virtual device optimization";
111+
}
112+
return use_virtual_device;
113+
}
114+
115+
bool UseVirtualDevice() {
116+
static bool use_virtual_device = ShouldUseVirtualDevice();
117+
return use_virtual_device;
118+
}
119+
104120
} // namespace torch_xla

torch_xla/csrc/device.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ static inline torch::lazy::BackendDevice GetDeviceOrCurrent(
4242
return device != nullptr ? *device : GetCurrentDevice();
4343
}
4444

45+
// Test whether the XLA_USE_SPMD environment variable is set to enable the
46+
// virtual device optimization.
47+
bool UseVirtualDevice();
48+
4549
} // namespace torch_xla
4650

4751
#endif // XLA_TORCH_XLA_CSRC_DEVICE_H_

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ std::string GetXLATensorDebugInfo(const at::Tensor& tensor) {
384384
auto at_tensor = xtensor->CurrentTensorData();
385385
ss << "Tensor on host: ";
386386
if (at_tensor) {
387-
ss << " with size " << at_tensor->sizes() << "\n";
387+
ss << "with size " << at_tensor->sizes() << "\n";
388388
} else {
389389
ss << "None\n";
390390
}
@@ -1126,7 +1126,7 @@ void InitXlaModuleBindings(py::module m) {
11261126
[](const std::vector<std::string>& devices) {
11271127
NoGilSection nogil;
11281128
XLAGraphExecutor::Get()->WaitDeviceOps(devices);
1129-
if (ShardingUtil::UseVirtualDevice()) {
1129+
if (UseVirtualDevice()) {
11301130
std::vector<std::string> spmd_device = {"SPMD:0"};
11311131
runtime::GetComputationClient()->WaitDeviceOps(spmd_device);
11321132
} else {
@@ -1313,8 +1313,7 @@ void InitXlaModuleBindings(py::module m) {
13131313
const py::list& group_assignment, const py::list& replication_groups,
13141314
int sharding_type) {
13151315
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
1316-
XLA_CHECK(ShardingUtil::UseVirtualDevice())
1317-
<< "Please set `XLA_USE_SPMD=1`";
1316+
XLA_CHECK(UseVirtualDevice()) << "Please set `XLA_USE_SPMD=1`";
13181317
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
13191318
xla::OpSharding sharding = ShardingUtil::CreateOpSharding(
13201319
tile_assignment, group_assignment, replication_groups,
@@ -1393,23 +1392,33 @@ void InitXlaModuleBindings(py::module m) {
13931392
// shape. Note that this padding is _not_ included in the global indices
13941393
// returned by `_get_local_shard_indices`.
13951394
m.def("_get_local_shards",
1396-
[](const at::Tensor& input) -> std::vector<at::Tensor> {
1395+
[](const at::Tensor& input)
1396+
-> std::tuple<std::vector<at::Tensor>, std::vector<std::string>> {
13971397
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
13981398
XLA_CHECK(xtensor->GetXlaData() != nullptr)
13991399
<< "Shard data is not available";
14001400
XLA_CHECK(xtensor->sharding_spec() != nullptr)
14011401
<< "Tensor is not sharded";
1402-
XLA_CHECK(ShardingUtil::UseVirtualDevice())
1402+
XLA_CHECK(UseVirtualDevice())
14031403
<< "Virtual device must be enabled to use _get_local_shards";
14041404
auto handle = UnwrapXlaData(xtensor->GetXlaData());
1405-
auto shard_handles =
1405+
std::vector<runtime::ComputationClient::DataPtr> shard_handles =
14061406
runtime::GetComputationClient()->GetDataShards(handle);
14071407
std::vector<at::Tensor> shards;
1408-
for (auto& shard_handle : shard_handles) {
1409-
auto xshard = XLATensor::Create(WrapXlaData(shard_handle));
1410-
shards.push_back(bridge::AtenFromXlaTensor(std::move(xshard)));
1408+
std::vector<std::string> str_devices;
1409+
shards.reserve(shard_handles.size());
1410+
str_devices.reserve(shard_handles.size());
1411+
// Transfer shards from the device and create cpu tensors.
1412+
for (const runtime::ComputationClient::DataPtr shard_handle :
1413+
shard_handles) {
1414+
shards.push_back(
1415+
XlaDataToTensors(
1416+
{WrapXlaData(shard_handle)},
1417+
TensorTypeFromXlaType(shard_handle->shape().element_type()))
1418+
.front());
1419+
str_devices.push_back(shard_handle->device());
14111420
}
1412-
return shards;
1421+
return std::make_tuple(shards, str_devices);
14131422
});
14141423
// Returns the indices of the shards into the global tensor as either
14151424
// a Python list of slices for each dimension or a Python Ellipsis object
@@ -1478,8 +1487,7 @@ void InitXlaModuleBindings(py::module m) {
14781487
<< "Input shard shape must include padding: " << shard.sizes()
14791488
<< " vs " << shard_shape;
14801489
}
1481-
auto xla_devices = GetXlaDevices(devices);
1482-
auto xla_data = ShardingUtil::CreateShardedData(shards, xla_devices,
1490+
auto xla_data = ShardingUtil::CreateShardedData(shards, devices,
14831491
xtensor->shape(), sharding);
14841492
xtensor->SetXlaData(WrapXlaData(xla_data));
14851493
});
@@ -1677,8 +1685,8 @@ void InitXlaModuleBindings(py::module m) {
16771685
torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str());
16781686
// Device will be Virtual device if SPMD is enabled.
16791687
torch::lazy::BackendDevice device =
1680-
ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
1681-
: torch_xla::GetCurrentDevice();
1688+
UseVirtualDevice() ? ParseDeviceString("SPMD:0")
1689+
: torch_xla::GetCurrentDevice();
16821690
auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier(
16831691
hash, graph_inputs, device);
16841692
std::vector<at::Tensor> retlist;

torch_xla/csrc/runtime/pjrt_computation_client.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,10 @@ PjRtComputationClient::ExecuteComputation(
533533
<< device;
534534
// Grab the shared lock and block the `WaitDeviceOps` until buffer is
535535
// ready.
536+
// TODO(JackCaoG): This lock should be acquired outside of the lockfn and
537+
// passed in. It is possible that lockfn started after ExecuteComputation
538+
// released the xla_graph_executor lock, which will create a short window
539+
// where the device is unlocked while execution is still running.
536540
auto lock = lock_device_shared(device);
537541
TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
538542
<< " Done";

torch_xla/csrc/tensor.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ torch::lazy::Value XLATensor::GetIrValue() const {
338338
c10::optional<at::Tensor> tensor_data = CurrentTensorData();
339339
XLA_CHECK(tensor_data);
340340
AssignIrValue(GetIrValueForTensor(*tensor_data, GetDevice()));
341+
data()->tensor_data = c10::nullopt;
341342
return data()->ir_value;
342343
}
343344

@@ -492,9 +493,8 @@ void XLATensor::SetTensor(at::Tensor tensor) {
492493
}
493494

494495
void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
495-
torch::lazy::BackendDevice device = ShardingUtil::UseVirtualDevice()
496-
? ParseDeviceString("SPMD:0")
497-
: GetDevice();
496+
torch::lazy::BackendDevice device =
497+
UseVirtualDevice() ? ParseDeviceString("SPMD:0") : GetDevice();
498498
if (sync) {
499499
at::Tensor typed_tensor =
500500
torch::lazy::CopyTensor(tensor, dtype(), /*copy=*/false);

torch_xla/csrc/tensor_util.cpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,20 @@ bool Use32BitLong() {
106106
return use_32bit_long;
107107
}
108108

109+
bool IsTpuDevice(XlaDeviceType hw_type) {
110+
static bool spmd_device_is_tpu =
111+
(hw_type == XlaDeviceType::SPMD) &&
112+
runtime::GetComputationClient()->GetDefaultDevice().find("TPU") == 0;
113+
return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu;
114+
}
115+
109116
xla::PrimitiveType XlaTypeFromTensorType(
110117
at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) {
111118
XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
112119
switch (scalar_type) {
113120
case at::ScalarType::Double:
114-
return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::F64
115-
: xla::PrimitiveType::F32;
121+
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::F64
122+
: xla::PrimitiveType::F32;
116123
case at::ScalarType::Float:
117124
return xla::PrimitiveType::F32;
118125
case at::ScalarType::BFloat16:
@@ -600,19 +607,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(
600607
const at::Tensor& tensor, const xla::Shape& shape,
601608
const torch::lazy::BackendDevice& device) {
602609
TORCH_LAZY_TIMED("TensorToData");
603-
if (ShardingUtil::UseVirtualDevice()) {
604-
// Scalar value will be replicated, no need to delay the transfer here.
605-
// TODO(JackCaoG): fix this for more general cases.
606-
if (device.type() == (int8_t)XlaDeviceType::SPMD && shape.rank() > 0) {
607-
// When SPMD is enabled, we want to delay the data transfer for XLA
608-
// tensors until the data is sharded. So, we skip the data transfer
609-
// here and simply return a placeholder for the backend data ptr.
610-
// Data will only be transferred via CreateTensorsData, when users
611-
// call the mark_sharding API.
612-
return WrapXlaData(runtime::GetComputationClient()->CreateDataPlaceholder(
613-
"SPMD:0", shape));
614-
}
615-
610+
if (UseVirtualDevice()) {
616611
// The tensor is bypassing the virtual device, so it should be replicated
617612
// to all devices.
618613
std::vector<std::string> local_devices =
@@ -856,7 +851,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
856851
TORCH_LAZY_TIMED("TensorToData");
857852
XLA_CHECK_EQ(tensors.size(), devices.size());
858853

859-
if (ShardingUtil::UseVirtualDevice()) {
854+
if (UseVirtualDevice()) {
860855
// When running in SPMD mode, tensors here in the unsharded
861856
// CreateTensorsData should be implicitly replicated to all devices.
862857
// This case should always apply when using SPMD regardless
@@ -936,7 +931,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
936931

937932
std::vector<runtime::ComputationClient::TensorSource> source_tensors; // in
938933
std::vector<runtime::ComputationClient::DataPtr> new_handles; // out
939-
if (ShardingUtil::UseVirtualDevice()) {
934+
if (UseVirtualDevice()) {
940935
// GetLocalDevices returns the list of local devices specified by their
941936
// global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
942937
std::vector<std::string> local_devices =
@@ -1160,27 +1155,27 @@ xla::PrimitiveType GetDevicePrimitiveType(
11601155
if (DowncastBF16() || DowncastF16()) {
11611156
return xla::PrimitiveType::F32;
11621157
}
1163-
return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::F64
1164-
: xla::PrimitiveType::F32;
1158+
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::F64
1159+
: xla::PrimitiveType::F32;
11651160
case xla::PrimitiveType::F32:
11661161
if (UseF16() || DowncastF16()) {
11671162
return xla::PrimitiveType::F16;
11681163
}
11691164
return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
11701165
: xla::PrimitiveType::F32;
11711166
case xla::PrimitiveType::U16:
1172-
return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::U16
1173-
: xla::PrimitiveType::U32;
1167+
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::U16
1168+
: xla::PrimitiveType::U32;
11741169
case xla::PrimitiveType::S16:
1175-
return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::S16
1176-
: xla::PrimitiveType::S32;
1170+
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::S16
1171+
: xla::PrimitiveType::S32;
11771172
case xla::PrimitiveType::S64:
11781173
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
11791174
case xla::PrimitiveType::U64:
11801175
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
11811176
case xla::PrimitiveType::C128:
1182-
return hw_type != XlaDeviceType::TPU ? xla::PrimitiveType::C128
1183-
: xla::PrimitiveType::C64;
1177+
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128
1178+
: xla::PrimitiveType::C64;
11841179
default:
11851180
return type;
11861181
}

0 commit comments

Comments
 (0)