PR #11354: Passthrough sharding propagation for host offloading ops
Imported from GitHub PR openxla/xla#11354

To support host memory offloading, we propagate sharding through the offloading operations (the MoveToDevice and MoveToHost custom calls) so that the SPMD partitioner does not insert collectives in the wrong places.
Copybara import of the project:

--
6a4d73c56e618e921fa32e5e03756e40986a8ad3 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Passthrough sharding propagation for host offloading ops

Merging this change closes #11354

PiperOrigin-RevId: 626948525
jaro-sevcik authored and tensorflower-gardener committed Apr 22, 2024
1 parent e3f34e6 · commit 109956f
Showing 3 changed files with 40 additions and 3 deletions.
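As a minimal illustration of the behavior described above (an editor's sketch derived from the new test in sharding_propagation_test.cc, not text from the commit itself): with MoveToHost/MoveToDevice treated as passthrough ops, a MoveToHost custom call simply inherits its operand's sharding instead of blocking propagation.

Input HLO (excerpt from the new test):

  %param0 = f32[1,256,128] parameter(0), sharding={devices=[1,1,4]0,1,2,3}
  %custom-call.0 = f32[1,256,128] custom-call(f32[1,256,128] %param0), custom_call_target="MoveToHost"

After ShardingPropagation runs (the sharding the test asserts on custom-call.0):

  %custom-call.0 = f32[1,256,128] custom-call(f32[1,256,128] %param0), custom_call_target="MoveToHost", sharding={devices=[1,1,4]0,1,2,3}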
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
@@ -638,6 +638,7 @@ cc_library(
":dot_as_convolution_util",
":hlo_graph_dumper",
":hlo_pass",
":host_memory_offload_annotations_hdr",
"//xla:array",
"//xla:protobuf_util",
"//xla:shape_tree",
10 changes: 7 additions & 3 deletions third_party/xla/xla/service/sharding_propagation.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/protobuf_util.h"
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
@@ -277,9 +278,12 @@ bool IsPassthroughCustomOps(const HloInstruction* hlo) {
hlo->operand(0)->shape().rank() != hlo->shape().rank()) {
return false;
}
return hlo->IsCustomCall({"ResizeNearest", "ResizeBilinear",
"ResizeNearestGrad", "ResizeBilinearGrad",
"Cholesky"});

return hlo->IsCustomCall(
{"ResizeNearest", "ResizeBilinear", "ResizeNearestGrad",
"ResizeBilinearGrad", "Cholesky",
host_memory_offload_annotations::kMoveToDeviceCustomCallTarget,
host_memory_offload_annotations::kMoveToHostCustomCallTarget});
}

// Return the operand which is the most suitable for determining the sharding
32 changes: 32 additions & 0 deletions third_party/xla/xla/service/sharding_propagation_test.cc
@@ -9374,6 +9374,38 @@ ENTRY %reshape {
EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}"));
}

TEST_F(ShardingPropagationTest, OffloadingPropagation) {
const char* const hlo_string = R"(
HloModule module
ENTRY %offloading {
%param0 = f32[1,256,128] parameter(0), sharding={devices=[1,1,4]0,1,2,3}
%zero = f32[] constant(0.0)
%broadcast = f32[256,256,128] broadcast(%zero), dimensions={}
%izero = s32[] constant(0)
%custom-call.0 = f32[1,256,128] custom-call(f32[1,256,128] %param0), custom_call_target="MoveToHost"
%dynamic-update-slice = f32[256,256,128] dynamic-update-slice(%broadcast, %custom-call.0, %izero, %izero, %izero)
%dynamic-slice = f32[1,256,128] dynamic-slice(%dynamic-update-slice, %izero, %izero, %izero), dynamic_slice_sizes={1,256,128}
%custom-call.1 = f32[1,256,128] custom-call(f32[1,256,128] %dynamic-slice), custom_call_target="MoveToDevice"
ROOT %copy = f32[1,256,128] copy(%custom-call.1), sharding={devices=[1,4,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true)
.Run(module.get()));

XLA_VLOG_LINES(1, module->ToString());
EXPECT_TRUE(changed);

auto* to_host = FindInstruction(module.get(), "custom-call.0");
EXPECT_THAT(to_host, op::Sharding("{devices=[1,1,4]0,1,2,3}"));

auto* from_host_input =
FindInstruction(module.get(), "custom-call.1")->operand(0);
EXPECT_THAT(from_host_input, op::Sharding("{devices=[1,1,4]0,1,2,3}"));
}

TEST_P(ParameterizedMetadataTest, PropagateThroughSingleUsers) {
const char* const hlo_string = R"(
HloModule module