Skip to content

Commit

Permalink
[XLA:GPU] Make "collectives-schedule-linearizer" a last optimisation …
Browse files Browse the repository at this point in the history
…pass.

PiperOrigin-RevId: 636059760
  • Loading branch information
golechwierowicz authored and tensorflower-gardener committed May 22, 2024
1 parent e5592ac commit eb95654
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
19 changes: 11 additions & 8 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ absl::Status GpuCompiler::OptimizeHloModule(
TF_RETURN_IF_ERROR(RunPostFusionVerificationPasses(
hlo_module, stream_exec, options, gpu_target_config));

return absl::OkStatus();
return RunPreSchedulingPasses(hlo_module, stream_exec);
} // NOLINT(readability/fn_size)

AlgebraicSimplifierOptions GpuCompiler::GetAlgebraicSimplifierOptions(
Expand Down Expand Up @@ -1426,13 +1426,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
.VerifyReshapeIsBitcast(),
/*debug_only=*/true);

// Linearize collective schedule if online autotuning of convolutions is
// enabled.
pipeline.AddPass<CollectivesScheduleLinearizer>(
[this, stream_exec](const HloModule* module) {
return RequiresCollectiveScheduleLinearizer(module, stream_exec);
});

// Triton compilation needs normalized operations on bf16 (i.e. converted to
// f32).
add_float_normalization(pipeline);
Expand Down Expand Up @@ -2175,6 +2168,16 @@ absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
gpu_executable->dnn_compiled_graphs());
}

// Builds and runs the "pre-scheduling-passes" pipeline on `module`.
//
// The pipeline currently holds a single pass, CollectivesScheduleLinearizer,
// constructed with a predicate that forwards to
// RequiresCollectiveScheduleLinearizer(module, stream_exec).
// NOTE(review): presumably the pass only takes effect when that predicate
// returns true -- confirm against CollectivesScheduleLinearizer.
//
// Returns the status of the pipeline run; the "changed" bit is discarded.
absl::Status GpuCompiler::RunPreSchedulingPasses(
    HloModule* module, se::StreamExecutor* stream_exec) {
  // Named separately (and with a distinct parameter name) so the lambda's
  // argument does not shadow the outer `module`.
  auto needs_linearizer = [this, stream_exec](const HloModule* candidate) {
    return RequiresCollectiveScheduleLinearizer(candidate, stream_exec);
  };

  HloPassPipeline pre_scheduling("pre-scheduling-passes");
  pre_scheduling.AddPass<CollectivesScheduleLinearizer>(needs_linearizer);
  return pre_scheduling.Run(module).status();
}

absl::Status GpuCompiler::RunPostSchedulingPipelines(
HloModule* module, int64_t scheduler_mem_limit,
const se::DeviceDescription& gpu_device_info) const {
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ class GpuCompiler : public LLVMCompiler {
absl::Status SerializeAutotuneResultsToFile(
const DebugOptions& debug_options);

// Runs the "pre-scheduling-passes" HLO pipeline on `module` and returns the
// status of that run. The pipeline adds CollectivesScheduleLinearizer, whose
// predicate consults RequiresCollectiveScheduleLinearizer(module, stream_exec).
// NOTE(review): whether `stream_exec` may be null here (deviceless AOT) is not
// visible from this declaration -- confirm against callers.
absl::Status RunPreSchedulingPasses(HloModule* module,
se::StreamExecutor* stream_exec);

// During compilation with device, stream_exec != null and autotune_results
// == null. During deviceless AOT compilation, stream_exec == null and
// autotune_results != null.
Expand Down

0 comments on commit eb95654

Please sign in to comment.