Skip to content

Commit

Permalink
[XLA:GPU] Drop status from HloFusionAnalysis::Create.
Browse files Browse the repository at this point in the history
The only case in which we return a status is not worth checking, but it adds a lot of awkward handling of statuses and optionals. There should be no case in which we fail to create an analysis but can still recover and continue compilation.

PiperOrigin-RevId: 601676192
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Jan 26, 2024
1 parent 5f86f9a commit d3b04b8
Show file tree
Hide file tree
Showing 18 changed files with 118 additions and 155 deletions.
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/buffer_sharing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
stream_executor::GpuDeviceInfoProto device_info;
stream_executor::DeviceDescription device_description(device_info);
auto analysis = HloFusionAnalysis::Create(fusion, &device_description);
bool is_reduction_emitter = analysis->GetEmitterFusionKind() ==
bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
HloFusionAnalysis::EmitterFusionKind::kReduction;
const HloInstruction* reduction_hero =
is_reduction_emitter ? reduction_hero = analysis->FindHeroReduction()
is_reduction_emitter ? reduction_hero = analysis.FindHeroReduction()
: nullptr;

// We need to make sure that the fusion parameter is accessed in the same
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/fusions/input_slices_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,10 @@ TEST_F(InputSlicesTest, ThreadIndexing) {

auto* root = module->entry_computation()->root_instruction();
auto analysis_fused = AnalyzeFusion(*root, device_info);
ASSERT_NE(analysis_fused, std::nullopt);

TF_ASSERT_OK_AND_ASSIGN(
auto emitter,
GetFusionEmitter(PreBufferAssignmentFusionInfo{*analysis_fused}));
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}));
auto fusion = dynamic_cast<InputSlicesFusion*>(emitter.get());
ASSERT_NE(fusion, nullptr);

Expand Down
6 changes: 2 additions & 4 deletions third_party/xla/xla/service/gpu/fusions/loop_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ class LoopTest : public HloTestBase {
};

absl::StatusOr<std::unique_ptr<LoopFusion>> GetLoopFusion(
const std::optional<HloFusionAnalysis>& analysis) {
TF_RET_CHECK(analysis != std::nullopt);

const HloFusionAnalysis& analysis) {
TF_ASSIGN_OR_RETURN(
auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{*analysis}));
auto emitter, GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}));
auto fusion = dynamic_cast<LoopFusion*>(emitter.get());
TF_RET_CHECK(fusion != nullptr);

Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/service/gpu/fusions/scatter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ TEST_F(ScatterFusionTest, ScatterFusion) {

auto* root = module->entry_computation()->root_instruction();
auto analysis_fused = AnalyzeFusion(*root, device_info);
ASSERT_NE(analysis_fused, std::nullopt);

TF_ASSERT_OK_AND_ASSIGN(
auto emitter,
GetFusionEmitter(PreBufferAssignmentFusionInfo{*analysis_fused}));
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}));
auto scatter_fusion = dynamic_cast<ScatterFusion*>(emitter.get());
ASSERT_NE(scatter_fusion, nullptr);
EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(),
Expand Down
6 changes: 2 additions & 4 deletions third_party/xla/xla/service/gpu/fusions/triton_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,10 @@ TEST_F(TritonFusionTest, TritonSoftmaxFusion) {
auto* root = module->entry_computation()->root_instruction();
auto analysis_fused =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis_fused, std::nullopt);

TF_ASSERT_OK_AND_ASSIGN(
auto emitter_fused,
GetFusionEmitter(PreBufferAssignmentFusionInfo{*analysis_fused}));
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}));
auto triton_fusion = dynamic_cast<TritonFusion*>(emitter_fused.get());
ASSERT_NE(triton_fusion, nullptr);
auto launch_dims = triton_fusion->launch_dimensions();
Expand All @@ -84,11 +83,10 @@ TEST_F(TritonFusionTest, TritonSoftmaxFusion) {
EXPECT_EQ(launch_dims->num_threads_per_block(), 32);

auto analysis_consumer = AnalyzeFusion(*root, device_info);
ASSERT_NE(analysis_consumer, std::nullopt);

TF_ASSERT_OK_AND_ASSIGN(
auto emitter_consumer,
GetFusionEmitter(PreBufferAssignmentFusionInfo{*analysis_consumer}));
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_consumer}));
ASSERT_NE(dynamic_cast<TritonFusion*>(emitter_consumer.get()), nullptr);
}

Expand Down
25 changes: 11 additions & 14 deletions third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ HloFusionAnalysis::HloFusionAnalysis(
input_output_info_(std::move(input_output_info)) {}

// static
absl::StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
HloFusionAnalysis HloFusionAnalysis::Create(
FusionBackendConfig backend_config,
std::unique_ptr<HloFusionAdaptor> fusion,
const se::DeviceDescription* device_info) {
Expand Down Expand Up @@ -189,13 +189,14 @@ absl::string_view HloFusionAnalysis::GetEmitterFusionKindString(
}

// static
absl::StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
HloFusionAnalysis HloFusionAnalysis::Create(
const HloFusionInstruction* fusion,
const se::DeviceDescription* device_info) {
CHECK(device_info != nullptr);
TF_ASSIGN_OR_RETURN(auto gpu_config,
fusion->backend_config<GpuBackendConfig>());
FusionBackendConfig backend_config = gpu_config.fusion_backend_config();
FusionBackendConfig backend_config =
fusion->has_backend_config()
? fusion->backend_config<GpuBackendConfig>()->fusion_backend_config()
: FusionBackendConfig::default_instance();
return Create(std::move(backend_config),
HloFusionAdaptor::ForInstruction(fusion), device_info);
}
Expand Down Expand Up @@ -321,10 +322,10 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const {
LOG(FATAL) << "Did not find a hero reduction";
}

std::optional<HloFusionAnalysis> AnalyzeProducerConsumerFusion(
HloFusionAnalysis AnalyzeProducerConsumerFusion(
const HloInstruction& producer, const HloInstruction& consumer,
const se::DeviceDescription& device_info) {
auto ret = HloFusionAnalysis::Create(
return HloFusionAnalysis::Create(
consumer.has_backend_config()
? consumer.backend_config<GpuBackendConfig>()->fusion_backend_config()
: producer.backend_config<GpuBackendConfig>()
Expand All @@ -333,17 +334,13 @@ std::optional<HloFusionAnalysis> AnalyzeProducerConsumerFusion(
HloFusionAdaptor::ForInstruction(&producer),
HloFusionAdaptor::ForInstruction(&consumer)),
&device_info);
if (!ret.ok()) return std::nullopt;
return {std::move(*ret)};
}

std::optional<HloFusionAnalysis> AnalyzeFusion(
const HloInstruction& consumer, const se::DeviceDescription& device_info) {
auto ret = HloFusionAnalysis::Create(
HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer,
const se::DeviceDescription& device_info) {
return HloFusionAnalysis::Create(
consumer.backend_config<GpuBackendConfig>()->fusion_backend_config(),
HloFusionAdaptor::ForInstruction(&consumer), &device_info);
if (!ret.ok()) return std::nullopt;
return {std::move(*ret)};
}

} // namespace gpu
Expand Down
19 changes: 8 additions & 11 deletions third_party/xla/xla/service/gpu/hlo_fusion_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ limitations under the License.
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/statusor.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
Expand Down Expand Up @@ -55,13 +54,11 @@ class HloFusionAnalysis {
int smallest_input_dtype_bits;
};

static absl::StatusOr<HloFusionAnalysis> Create(
FusionBackendConfig backend_config,
std::unique_ptr<HloFusionAdaptor> fusion,
const se::DeviceDescription* device_info);
static absl::StatusOr<HloFusionAnalysis> Create(
const HloFusionInstruction* fusion,
const se::DeviceDescription* device_info);
static HloFusionAnalysis Create(FusionBackendConfig backend_config,
std::unique_ptr<HloFusionAdaptor> fusion,
const se::DeviceDescription* device_info);
static HloFusionAnalysis Create(const HloFusionInstruction* fusion,
const se::DeviceDescription* device_info);

const std::vector<const HloInstruction*>& fusion_roots() const {
return fusion_roots_;
Expand Down Expand Up @@ -118,14 +115,14 @@ class HloFusionAnalysis {

// Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer
// into consumer.
std::optional<HloFusionAnalysis> AnalyzeProducerConsumerFusion(
HloFusionAnalysis AnalyzeProducerConsumerFusion(
const HloInstruction& producer, const HloInstruction& consumer,
const se::DeviceDescription& device_info);

// Creates a HloFusionAnalysis that analyzes just consumer as a standalone
// fusion.
std::optional<HloFusionAnalysis> AnalyzeFusion(
const HloInstruction& consumer, const se::DeviceDescription& device_info);
HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer,
const se::DeviceDescription& device_info);

} // namespace gpu
} // namespace xla
Expand Down
49 changes: 19 additions & 30 deletions third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ TEST_F(HloFusionAnalysisTest, DoesNotPeekOutsideBoundary) {

auto* root = module->entry_computation()->root_instruction();
auto analysis = AnalyzeFusion(*root, device_info);
ASSERT_NE(analysis, std::nullopt);
EXPECT_EQ(analysis->GetEmitterFusionKind(),
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kLoop);

auto analysis_fused =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis_fused, std::nullopt);
EXPECT_EQ(analysis_fused->GetEmitterFusionKind(),
EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}

Expand Down Expand Up @@ -89,12 +87,11 @@ TEST_F(HloFusionAnalysisTest, ReductionWithMultipleUsers) {

auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();

TF_ASSERT_OK_AND_ASSIGN(
auto analysis, HloFusionAnalysis::Create(
FusionBackendConfig::default_instance(),
HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction()),
&device_info));
auto analysis = HloFusionAnalysis::Create(
FusionBackendConfig::default_instance(),
HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction()),
&device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
Expand Down Expand Up @@ -125,10 +122,9 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusion) {
auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();

auto* root = module->entry_computation()->root_instruction();
TF_ASSERT_OK_AND_ASSIGN(
auto analysis, HloFusionAnalysis::Create(
FusionBackendConfig::default_instance(),
HloFusionAdaptor::ForInstruction(root), &device_info));
auto analysis = HloFusionAnalysis::Create(
FusionBackendConfig::default_instance(),
HloFusionAdaptor::ForInstruction(root), &device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}
Expand Down Expand Up @@ -162,8 +158,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFused) {

auto analysis =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis, std::nullopt);
EXPECT_EQ(analysis->GetEmitterFusionKind(),
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}

Expand Down Expand Up @@ -194,8 +189,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInConsumer) {
auto* root = module->entry_computation()->root_instruction();
auto analysis =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis, std::nullopt);
EXPECT_EQ(analysis->GetEmitterFusionKind(),
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}

Expand Down Expand Up @@ -232,8 +226,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInBoth) {
auto* root = module->entry_computation()->root_instruction();
auto analysis =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis, std::nullopt);
EXPECT_EQ(analysis->GetEmitterFusionKind(),
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}

Expand Down Expand Up @@ -266,8 +259,7 @@ TEST_F(HloFusionAnalysisTest, ReduceMultiOutputFusionWithTransposeBitcast) {
auto* root = module->entry_computation()->root_instruction();
auto analysis =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis, std::nullopt);
EXPECT_EQ(analysis->GetEmitterFusionKind(),
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}

Expand Down Expand Up @@ -300,10 +292,9 @@ TEST_F(HloFusionAnalysisTest, InvalidReduceMultiOutputFusion) {
auto* root = module->entry_computation()->root_instruction();
auto analysis =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis, std::nullopt);
// We expect to fallback to the loop emitter, because the two reductions are
// not compatible as they reduce over different dimensions.
EXPECT_EQ(analysis->GetEmitterFusionKind(),
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kLoop);
}

Expand Down Expand Up @@ -333,8 +324,7 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) {
auto* root = module->entry_computation()->root_instruction();
auto analysis_fused =
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
ASSERT_NE(analysis_fused, std::nullopt);
EXPECT_EQ(analysis_fused->GetEmitterFusionKind(),
EXPECT_EQ(analysis_fused.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kReduction);
}

Expand All @@ -359,10 +349,9 @@ TEST_F(HloFusionAnalysisTest, ConcatFusion) {
auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();

auto* root = module->entry_computation()->root_instruction();
TF_ASSERT_OK_AND_ASSIGN(
auto analysis, HloFusionAnalysis::Create(
FusionBackendConfig::default_instance(),
HloFusionAdaptor::ForInstruction(root), &device_info));
auto analysis = HloFusionAnalysis::Create(
FusionBackendConfig::default_instance(),
HloFusionAdaptor::ForInstruction(root), &device_info);
EXPECT_EQ(analysis.GetEmitterFusionKind(),
HloFusionAnalysis::EmitterFusionKind::kConcatenate);
}
Expand Down
9 changes: 3 additions & 6 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2539,8 +2539,7 @@ absl::Status IrEmitterUnnested::EmitFusion(
// Create HloFusionAnalysis instance.
const se::DeviceDescription& device_info =
ir_emitter_context_->gpu_device_info();
TF_ASSIGN_OR_RETURN(auto fusion_analysis,
HloFusionAnalysis::Create(fusion, &device_info));
auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_info);

TF_ASSIGN_OR_RETURN(
std::unique_ptr<FusionInterface> emitter,
Expand Down Expand Up @@ -4149,8 +4148,7 @@ absl::Status IrEmitterUnnested::EmitOp(
Cast<HloFusionInstruction>(hlo_for_lmhlo.at(op));
const se::DeviceDescription& device_info =
ir_emitter_context_->gpu_device_info();
TF_ASSIGN_OR_RETURN(auto fusion_analysis,
HloFusionAnalysis::Create(instr, &device_info));
auto fusion_analysis = HloFusionAnalysis::Create(instr, &device_info);
return EmitFusion(instr, fusion_analysis);
}

Expand Down Expand Up @@ -4470,8 +4468,7 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
auto* fusion = Cast<HloFusionInstruction>(instr);
const se::DeviceDescription& device_info =
ir_emitter_context_->gpu_device_info();
TF_ASSIGN_OR_RETURN(auto fusion_analysis,
HloFusionAnalysis::Create(fusion, &device_info));
auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_info);
return EmitFusion(fusion, fusion_analysis);
}
case HloOpcode::kInfeed:
Expand Down
16 changes: 6 additions & 10 deletions third_party/xla/xla/service/gpu/model/coalescing_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,13 @@ namespace gpu {

using mlir::AffineMap;

bool IsReadCoalescedHeuristic(
const std::optional<HloFusionAnalysis>& fusion_analysis,
const HloInstruction* producer, const HloInstruction* consumer) {
auto analyzed_kind_or_reduction =
fusion_analysis ? fusion_analysis->GetEmitterFusionKind()
: HloFusionAnalysis::EmitterFusionKind::kReduction;
bool IsReadCoalescedHeuristic(const HloFusionAnalysis& fusion_analysis,
const HloInstruction* producer,
const HloInstruction* consumer) {
auto fusion_kind = fusion_analysis.GetEmitterFusionKind();

// Transposing minor dimension breaks coalescing.
if (analyzed_kind_or_reduction !=
HloFusionAnalysis::EmitterFusionKind::kTranspose) {
if (fusion_kind != HloFusionAnalysis::EmitterFusionKind::kTranspose) {
auto is_broadcast = [&](const HloInstruction* instr) {
while (true) {
if (instr->opcode() == HloOpcode::kBroadcast) return true;
Expand Down Expand Up @@ -77,8 +74,7 @@ bool IsReadCoalescedHeuristic(
}

// Fusing two row reductions breaks coalescing.
if (analyzed_kind_or_reduction ==
HloFusionAnalysis::EmitterFusionKind::kReduction &&
if (fusion_kind == HloFusionAnalysis::EmitterFusionKind::kReduction &&
IsInputFusibleReduction(*producer) && consumer &&
IsInputFusibleReduction(*consumer)) {
return false;
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/gpu/model/coalescing_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace gpu {
// Returns true if all input reads are coalesced. If consumer is not nullptr,
// producer and consumer are considered as one fusion, otherwise it's only the
// producer.
bool IsReadCoalescedHeuristic(
const std::optional<HloFusionAnalysis>& fusion_analysis,
const HloInstruction* producer, const HloInstruction* consumer = nullptr);
bool IsReadCoalescedHeuristic(const HloFusionAnalysis& fusion_analysis,
const HloInstruction* producer,
const HloInstruction* consumer = nullptr);

// Returns true, if operand's read is coalesced.
bool IsReadCoalesced(const HloInstruction* operand, const HloInstruction* instr,
Expand Down
Loading

0 comments on commit d3b04b8

Please sign in to comment.