Commit ce1054f

Apply sharding when creating tensors data in data loader (#4995)
1 parent ede1ad0 commit ce1054f

4 files changed (+99, -24 lines)

test/spmd/test_xla_sharding.py

Lines changed: 25 additions & 0 deletions
@@ -9,6 +9,7 @@
 import torch.optim as optim
 import torch_xla
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
 import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
 import test_xla_sharding_base
@@ -170,6 +171,30 @@ def test_transfer_sharded_data_to_host(self):
     t1 = xt1.cpu()
     self.assertTrue(torch.allclose(t1, torch.ones(16, 16)))

+  def test_send_cpu_data_to_device_with_sharding(self):
+    xm.mark_step()  # Execute pending graph to avoid contaminating metrics
+    met.clear_all()
+    tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
+    mesh = self._get_mesh((1, self.n_devices))
+
+    # Create a ShardingSpec and use it to shard the tensor while sending to device
+    sharding_spec = xs.ShardingSpec(mesh, (0, 1))
+    self.assertTrue(sharding_spec.can_apply(tensor))
+    xtensors = xm.send_cpu_data_to_device([tensor],
+                                          xm.xla_device(),
+                                          input_sharding=sharding_spec)
+    self.assertEqual(len(xtensors), 1)
+    outbound = met.metric_data("OutboundData")[1]
+    self.assertEqual(outbound, tensor.element_size() * tensor.nelement())
+
+    # Verify the resulting sharding annotation matches an explicit `mark_sharding` call
+    xt = xtensors[0]
+    explicit_xt = tensor.to(xm.xla_device())
+    xs.mark_sharding(explicit_xt, mesh, (0, 1))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(xt),
+        torch_xla._XLAC._get_xla_sharding_spec(explicit_xt))
+

 if __name__ == '__main__':
   test = unittest.main()

torch_xla/core/xla_model.py

Lines changed: 4 additions & 4 deletions
@@ -967,11 +967,11 @@ def send_cpu_data_to_device(data, device, input_sharding=None):

   def convert_fn(tensors):
     devices = [str(device)] * len(tensors)
-    xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)
+    shardings = None
     if input_sharding:
-      for xtensor in xtensors:
-        if input_sharding.can_apply(xtensor):
-          input_sharding.apply(xtensor)
+      shardings = [input_sharding.xla_spec(t) for t in tensors]
+    xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
+                                                      shardings)
     return xtensors

   def select_fn(v):
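
The data loader path now builds sharding annotations on the CPU side: `convert_fn` maps `input_sharding` to one `XlaShardingSpec` (or `None`) per tensor and passes the list into `_xla_tensors_from_aten`, so tensors are created on device with their sharding already attached instead of being annotated after an unsharded transfer. A minimal usage sketch of the new path, mirroring the test above; the 1 x N mesh construction here is an assumption (the test builds its mesh through a `_get_mesh` helper):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumed: a 1 x N logical mesh over all addressable devices.
num_devices = len(xm.get_xla_supported_devices())
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

# Shard dim 1 of the batch across the mesh while sending it to the device.
batch = torch.randn(4, 4)
spec = xs.ShardingSpec(mesh, (0, 1))
(xbatch,) = xm.send_cpu_data_to_device([batch],
                                       xm.xla_device(),
                                       input_sharding=spec)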

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 45 additions & 18 deletions
@@ -410,13 +410,25 @@ std::ptrdiff_t GetTensorId(const at::Tensor& tensor) {

 std::vector<at::Tensor> GetXlaTensorsFromAten(
     const std::vector<at::Tensor>& aten_tensors,
-    const std::vector<std::string>& devices) {
-  auto data_handles = CreateTensorsData(aten_tensors, GetXlaDevices(devices));
+    const std::vector<std::string>& devices,
+    const std::optional<std::vector<XLATensor::ShardingSpecPtr>>
+        sharding_specs) {
+  std::vector<std::shared_ptr<torch::lazy::BackendData>> data_handles;
+  if (sharding_specs.has_value()) {
+    data_handles = CreateTensorsData(aten_tensors, sharding_specs.value(),
+                                     GetXlaDevices(devices));
+  } else {
+    data_handles = CreateTensorsData(aten_tensors, GetXlaDevices(devices));
+  }

   std::vector<at::Tensor> xla_tensors;
   xla_tensors.reserve(data_handles.size());
-  for (auto& data_handle : data_handles) {
+  for (int i = 0; i < data_handles.size(); i++) {
+    auto& data_handle = data_handles[i];
     XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
+    if (sharding_specs.has_value() && sharding_specs.value()[i] != nullptr) {
+      xla_tensor->SetShardingSpec(*sharding_specs.value()[i]);
+    }
     xla_tensors.push_back(bridge::AtenFromXlaTensor(std::move(xla_tensor)));
   }
   return xla_tensors;
@@ -904,21 +916,36 @@ void InitXlaModuleBindings(py::module m) {
         [](const std::vector<at::Tensor>& tensors) -> std::string {
          return GetTensorsHloGraph(tensors);
         });
-  m.def("_xla_tensors_from_aten", [](const std::vector<at::Tensor>& tensors,
-                                     const std::vector<std::string>& devices) {
-    std::vector<at::Tensor> result;
-    {
-      NoGilSection nogil;
-      std::vector<at::Tensor> xla_tensors =
-          GetXlaTensorsFromAten(tensors, devices);
-      result.reserve(xla_tensors.size());
-      for (size_t i = 0; i < xla_tensors.size(); ++i) {
-        result.push_back(torch::autograd::make_variable(
-            xla_tensors[i], /*requires_grad=*/tensors.at(i).requires_grad()));
-      }
-    }
-    return result;
-  });
+  py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>(
+      m, "XlaShardingSpec")
+      .def(py::init([](at::Tensor tensor, py::list& tile_assignment,
+                       bool replicated, bool manual) {
+        auto op_sharding =
+            ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual);
+        auto shape = CreateComputationShapeFromTensor(tensor, nullptr);
+        return std::make_shared<XLATensor::ShardingSpec>(op_sharding, shape);
+      }));
+  m.def("_xla_tensors_from_aten",
+        [](const std::vector<at::Tensor>& tensors,
+           const std::vector<std::string>& devices,
+           const std::optional<std::vector<XLATensor::ShardingSpecPtr>>&
+               shardings) {
+          std::vector<at::Tensor> result;
+          {
+            NoGilSection nogil;
+            std::vector<at::Tensor> xla_tensors =
+                GetXlaTensorsFromAten(tensors, devices, shardings);
+            result.reserve(xla_tensors.size());
+            for (size_t i = 0; i < xla_tensors.size(); ++i) {
+              result.push_back(torch::autograd::make_variable(
+                  xla_tensors[i],
+                  /*requires_grad=*/tensors.at(i).requires_grad()));
+            }
+          }
+          return result;
+        },
+        py::arg("tensors"), py::arg("devices"),
+        py::arg("shardings") = py::none());
   m.def("_xla_get_cpu_tensors", [](const std::vector<at::Tensor>& tensors) {
     std::vector<at::Tensor> result;
     {
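
The binding keeps the old call shape working: `_xla_tensors_from_aten` gains an optional third `shardings` argument that defaults to `py::none()`, and each list entry may be an `XlaShardingSpec` or `None` (entries left as `None` are created without a sharding annotation). A short sketch of both call shapes from Python; the tensors and devices here are illustrative only:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

tensors = [torch.ones(4, 4)]
devices = [str(xm.xla_device())] * len(tensors)

# Old call shape: `shardings` falls back to py::none(), no sharding applied.
xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)

# New call shape: one sharding (or None) per tensor, as produced by
# ShardingSpec.xla_spec(t); None entries are transferred unsharded.
xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                  [None for _ in tensors])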

torch_xla/experimental/xla_sharding.py

Lines changed: 25 additions & 2 deletions
@@ -1,6 +1,6 @@
 import os
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
@@ -78,7 +78,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   Annotates the tensor provided with XLA partition spec. Internally,
   it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
   Args:
-        t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_sepc.
+        t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.

         mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.

@@ -148,6 +148,29 @@ class ShardingSpec:
   mesh: Mesh
   partition_spec: Tuple[Union[int, None]]

+  # Derived fields
+  _tile_assignment: List[int] = field(init=False)
+  _replicated: bool = field(init=False)
+  _partial: bool = field(init=False)
+
+  def __post_init__(self):
+    self._tile_assignment = self.mesh.get_logical_mesh().tolist()
+    self._replicated = all(d is None for d in self.partition_spec)
+    self._partial = not self._replicated and any(
+        d is None for d in self.partition_spec)
+    # TODO(yeounoh) support partially replicated sharding.
+    assert not self._partial, "Partial replication is currently not supported"
+
+  def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
+    """
+    Create an XlaShardingSpec for the given tensor. If the tensor is
+    incompatible with the ShardingSpec, returns None.
+    """
+    if not self.can_apply(t):
+      return None
+    return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
+                                           self._replicated, False)
+
   def can_apply(self, t: torch.Tensor) -> bool:
     """
     Test whether the ShardingSpec is compatible with the given torch.Tensor.
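
`ShardingSpec` now derives its tile assignment and replication flags once in `__post_init__` and exposes them through `xla_spec(t)`, which wraps them into the pybind-level `XlaShardingSpec` consumed by the binding above; a mix of sharded and `None` dimensions (partial replication) is still rejected by the assert. A short sketch of the derived behavior, assuming a hypothetical 2 x 2 mesh over four devices:

import numpy as np
import torch
import torch_xla.experimental.xla_sharding as xs

mesh = xs.Mesh(np.arange(4), (2, 2))  # assumed 2 x 2 mesh over 4 devices

tiled = xs.ShardingSpec(mesh, (0, 1))             # both dims sharded: not replicated
replicated = xs.ShardingSpec(mesh, (None, None))  # no dims sharded: fully replicated

t = torch.zeros(4, 4)
xla_spec = tiled.xla_spec(t)  # XlaShardingSpec if tiled.can_apply(t), else None

# Mixing sharded and unsharded dims is partial replication, which
# __post_init__ currently rejects.
try:
  xs.ShardingSpec(mesh, (0, None))
except AssertionError:
  pass  # "Partial replication is currently not supported"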
