@@ -410,13 +410,25 @@ std::ptrdiff_t GetTensorId(const at::Tensor& tensor) {
 
 std::vector<at::Tensor> GetXlaTensorsFromAten(
     const std::vector<at::Tensor>& aten_tensors,
-    const std::vector<std::string>& devices) {
-  auto data_handles = CreateTensorsData(aten_tensors, GetXlaDevices(devices));
+    const std::vector<std::string>& devices,
+    const std::optional<std::vector<XLATensor::ShardingSpecPtr>>
+        sharding_specs) {
+  std::vector<std::shared_ptr<torch::lazy::BackendData>> data_handles;
+  if (sharding_specs.has_value()) {
+    data_handles = CreateTensorsData(aten_tensors, sharding_specs.value(),
+                                     GetXlaDevices(devices));
+  } else {
+    data_handles = CreateTensorsData(aten_tensors, GetXlaDevices(devices));
+  }
 
   std::vector<at::Tensor> xla_tensors;
   xla_tensors.reserve(data_handles.size());
-  for (auto& data_handle : data_handles) {
+  for (int i = 0; i < data_handles.size(); i++) {
+    auto& data_handle = data_handles[i];
     XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle));
+    if (sharding_specs.has_value() && sharding_specs.value()[i] != nullptr) {
+      xla_tensor->SetShardingSpec(*sharding_specs.value()[i]);
+    }
     xla_tensors.push_back(bridge::AtenFromXlaTensor(std::move(xla_tensor)));
   }
   return xla_tensors;
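For context, a minimal sketch of how this overload is meant to be driven from Python once the bindings below land. The device strings and tile assignment are illustrative assumptions, not values from the patch; the key invariant is that `sharding_specs` is aligned one-to-one with `aten_tensors`, and a null (Python `None`) entry leaves that tensor without a sharding annotation.

# Illustrative sketch, not part of the patch: assumes a 2-device setup;
# the "TPU:0"/"TPU:1" strings and the [[0], [1]] tile assignment are
# made-up example values.
import torch
from torch_xla import _XLAC

t0 = torch.randn(8, 8)
t1 = torch.randn(8, 8)

# Shard t0 across two devices along dim 0 (tile_assignment is handed
# to ShardingUtil::CreateOpSharding as a py::list).
spec = _XLAC.XlaShardingSpec(t0, [[0], [1]], False, False)

# The None entry becomes a nullptr ShardingSpecPtr, so t1 is uploaded
# without SetShardingSpec being called on it.
xla_t0, xla_t1 = _XLAC._xla_tensors_from_aten(
    [t0, t1], ["TPU:0", "TPU:1"], [spec, None])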
@@ -904,21 +916,36 @@ void InitXlaModuleBindings(py::module m) {
         [](const std::vector<at::Tensor>& tensors) -> std::string {
           return GetTensorsHloGraph(tensors);
         });
-  m.def("_xla_tensors_from_aten", [](const std::vector<at::Tensor>& tensors,
-                                     const std::vector<std::string>& devices) {
-    std::vector<at::Tensor> result;
-    {
-      NoGilSection nogil;
-      std::vector<at::Tensor> xla_tensors =
-          GetXlaTensorsFromAten(tensors, devices);
-      result.reserve(xla_tensors.size());
-      for (size_t i = 0; i < xla_tensors.size(); ++i) {
-        result.push_back(torch::autograd::make_variable(
-            xla_tensors[i], /*requires_grad=*/tensors.at(i).requires_grad()));
-      }
-    }
-    return result;
-  });
+  py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>(
+      m, "XlaShardingSpec")
+      .def(py::init([](at::Tensor tensor, py::list& tile_assignment,
+                       bool replicated, bool manual) {
+        auto op_sharding =
+            ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual);
+        auto shape = CreateComputationShapeFromTensor(tensor, nullptr);
+        return std::make_shared<XLATensor::ShardingSpec>(op_sharding, shape);
+      }));
+  m.def("_xla_tensors_from_aten",
+        [](const std::vector<at::Tensor>& tensors,
+           const std::vector<std::string>& devices,
+           const std::optional<std::vector<XLATensor::ShardingSpecPtr>>&
+               shardings) {
+          std::vector<at::Tensor> result;
+          {
+            NoGilSection nogil;
+            std::vector<at::Tensor> xla_tensors =
+                GetXlaTensorsFromAten(tensors, devices, shardings);
+            result.reserve(xla_tensors.size());
+            for (size_t i = 0; i < xla_tensors.size(); ++i) {
+              result.push_back(torch::autograd::make_variable(
+                  xla_tensors[i],
+                  /*requires_grad=*/tensors.at(i).requires_grad()));
+            }
+          }
+          return result;
+        },
+        py::arg("tensors"), py::arg("devices"),
+        py::arg("shardings") = py::none());
   m.def("_xla_get_cpu_tensors", [](const std::vector<at::Tensor>& tensors) {
     std::vector<at::Tensor> result;
     {
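Since `shardings` defaults to `py::none()`, pre-existing Python call sites continue to work without modification. A hedged sketch of both calling conventions (the device string and tile assignment are illustrative assumptions):

# Sketch only: demonstrates the backward-compatible binding surface.
import torch
from torch_xla import _XLAC

tensors = [torch.randn(2, 2)]

# Legacy call: the optional stays empty, so GetXlaTensorsFromAten takes
# the plain CreateTensorsData(aten_tensors, devices) branch.
unsharded = _XLAC._xla_tensors_from_aten(tensors, ["TPU:0"])

# Keyword-argument call, enabled by the py::arg declarations in the
# _xla_tensors_from_aten binding.
spec = _XLAC.XlaShardingSpec(tensors[0], [[0], [1]], False, False)
sharded = _XLAC._xla_tensors_from_aten(
    tensors=tensors, devices=["TPU:0"], shardings=[spec])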