Skip to content

Commit

Permalink
[XLA:GPU] Refactor HloFusionAdaptor to a flat structure.
Browse files Browse the repository at this point in the history
There is a need for HloInstructionAdaptor to know its parent HloFusionAdaptor to be able to distinguish fusion instructions for operands and users. The current nested structure of HloFusionAdaptor prevents us from having a convenient single parent pointer.

This CL adds a `parent` field to HloInstructionAdaptor, but doesn't set or use it consistently yet.

PiperOrigin-RevId: 627556942
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Apr 24, 2024
1 parent 0d6a9e6 commit 49837f3
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 126 deletions.
4 changes: 1 addition & 3 deletions third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
Expand Up @@ -306,9 +306,7 @@ HloFusionAnalysis AnalyzeProducerConsumerFusion(
? consumer.backend_config<GpuBackendConfig>()->fusion_backend_config()
: producer.backend_config<GpuBackendConfig>()
->fusion_backend_config(),
std::make_unique<ProducerConsumerFusion>(
HloFusionAdaptor::ForInstruction(&producer),
HloFusionAdaptor::ForInstruction(&consumer)),
HloFusionAdaptor::ForProducerConsumer(&producer, &consumer),
&device_info);
}

Expand Down
105 changes: 87 additions & 18 deletions third_party/xla/xla/service/gpu/hlo_traversal.cc
Expand Up @@ -15,11 +15,14 @@ limitations under the License.
#include "xla/service/gpu/hlo_traversal.h"

#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <queue>
#include <sstream>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
Expand Down Expand Up @@ -79,38 +82,42 @@ const HloInstruction* ResolveOperand(const HloInstruction* operand) {
}
return operand;
}
} // namespace

class SingleInstructionFusion : public HloFusionAdaptor {
class SingleInstructionFusion : public internal::HloFusionInstructionAdaptor {
public:
explicit SingleInstructionFusion(const HloInstruction* instruction)
: instruction_(*instruction) {
explicit SingleInstructionFusion(const HloInstruction* instruction,
const HloFusionAdaptor* parent)
: instruction_(instruction), parent_(parent) {
CHECK_NE(instruction->opcode(), HloOpcode::kFusion)
<< "Use HloFusionFusion";
}

bool ContainsInstruction(HloInstructionAdaptor instruction) const override {
return instruction == instruction_;
return &instruction.instruction() == instruction_;
}

absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const override {
return {instruction_};
return {HloInstructionAdaptor{*instruction_, parent_}};
}

absl::InlinedVector<HloInstructionAdaptor, 2> MakeInstructionPostOrder()
const override {
return {instruction_};
return {HloInstructionAdaptor{*instruction_, parent_}};
}

std::string ToString() const override { return instruction_.ToString(); }
std::string ToString() const override { return instruction_->ToString(); }

private:
HloInstructionAdaptor instruction_;
const HloInstruction* instruction_;
const HloFusionAdaptor* parent_;
};

class HloComputationFusion : public HloFusionAdaptor {
class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
public:
explicit HloComputationFusion(const HloComputation* computation)
: computation_(computation) {
explicit HloComputationFusion(const HloComputation* computation,
const HloFusionAdaptor* parent)
: computation_(computation), parent_(parent) {
// HloFusionAdaptor should only be created for fusion computations, which
// usually have only a few roots, but there is a case when we create it for
// non-fusion computations with thousands of roots. It happens inside
Expand All @@ -128,7 +135,7 @@ class HloComputationFusion : public HloFusionAdaptor {
}
}

static absl::InlinedVector<HloInstructionAdaptor, 2> FindRoots(
absl::InlinedVector<HloInstructionAdaptor, 2> FindRoots(
const HloComputation* computation) {
absl::InlinedVector<HloInstructionAdaptor, 2> roots;

Expand All @@ -140,7 +147,7 @@ class HloComputationFusion : public HloFusionAdaptor {
get_roots(operand);
}
} else {
HloInstructionAdaptor wrapped{*instr};
HloInstructionAdaptor wrapped{*instr, parent_};
if (roots_set.insert(wrapped).second) {
roots.push_back(wrapped);
}
Expand Down Expand Up @@ -181,7 +188,7 @@ class HloComputationFusion : public HloFusionAdaptor {
(instr->opcode() == HloOpcode::kTuple && instr->IsRoot())) {
continue;
}
result.emplace_back(*instr);
result.emplace_back(*instr, parent_);
}
return result;
}
Expand All @@ -191,21 +198,83 @@ class HloComputationFusion : public HloFusionAdaptor {
private:
const HloComputation* computation_;
absl::InlinedVector<HloInstructionAdaptor, 2> roots_;
const HloFusionAdaptor* parent_;
};

} // namespace

/*static*/
std::unique_ptr<HloFusionAdaptor> HloFusionAdaptor::ForInstruction(
const HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kFusion) {
return ForComputation(instruction->fused_instructions_computation());
}
return std::make_unique<SingleInstructionFusion>(instruction);

auto fusion_adaptor = std::make_unique<HloFusionAdaptor>();
fusion_adaptor->AddInstruction(instruction);
return fusion_adaptor;
}

/*static*/
std::unique_ptr<HloFusionAdaptor> HloFusionAdaptor::ForProducerConsumer(
const HloInstruction* producer, const HloInstruction* consumer) {
auto fusion_adaptor = std::make_unique<HloFusionAdaptor>();
fusion_adaptor->AddInstruction(producer);
fusion_adaptor->AddInstruction(consumer);
return fusion_adaptor;
}

/*static*/
std::unique_ptr<HloFusionAdaptor> HloFusionAdaptor::ForComputation(
const HloComputation* computation) {
return std::make_unique<HloComputationFusion>(computation);
auto fusion_adaptor = std::make_unique<HloFusionAdaptor>();
fusion_adaptor->AddComputation(computation);
return fusion_adaptor;
}

// Returns true if at least one of the wrapped fusion instructions reports
// that it contains `instruction`; returns false when none does (including
// when the adaptor holds no fusion instructions at all).
bool HloFusionAdaptor::ContainsInstruction(
    HloInstructionAdaptor instruction) const {
  return absl::c_any_of(fusion_instructions_, [&](const auto& member) {
    return member->ContainsInstruction(instruction);
  });
}

// Returns the roots of the fusion. Only the most recently added fusion
// instruction is consulted; for an adaptor built by ForProducerConsumer that
// is the consumer, because the producer is added before it.
//
// The adaptor must hold at least one fusion instruction.
absl::InlinedVector<HloInstructionAdaptor, 2> HloFusionAdaptor::GetRoots()
    const {
  // Calling back() on an empty vector is undefined behavior, so fail loudly
  // with a clear message instead of reading past the container.
  CHECK(!fusion_instructions_.empty())
      << "HloFusionAdaptor contains no fusion instructions";
  return fusion_instructions_.back()->GetRoots();
}

// Concatenates the post orders of all wrapped fusion instructions, in the
// order in which those fusion instructions were added to this adaptor.
absl::InlinedVector<HloInstructionAdaptor, 2>
HloFusionAdaptor::MakeInstructionPostOrder() const {
  absl::InlinedVector<HloInstructionAdaptor, 2> post_order;

  for (const auto& member : fusion_instructions_) {
    auto member_post_order = member->MakeInstructionPostOrder();
    // Append this member's instructions after everything collected so far.
    post_order.insert(post_order.end(),
                      std::make_move_iterator(member_post_order.begin()),
                      std::make_move_iterator(member_post_order.end()));
  }

  return post_order;
}

std::string HloFusionAdaptor::ToString() const {
std::ostringstream ss;
for (const auto& fusion_instruction : fusion_instructions_) {
ss << fusion_instruction->ToString() << "\n";
}
return ss.str();
}

// Appends `instruction` to this adaptor. A non-fusion instruction is wrapped
// on its own; a fusion instruction is represented by its fused computation.
// In both cases `this` is recorded as the parent of the new entry.
void HloFusionAdaptor::AddInstruction(const HloInstruction* instruction) {
  if (instruction->opcode() != HloOpcode::kFusion) {
    fusion_instructions_.push_back(
        std::make_unique<SingleInstructionFusion>(instruction, this));
    return;
  }
  AddComputation(instruction->fused_instructions_computation());
}

void HloFusionAdaptor::AddComputation(const HloComputation* computation) {
fusion_instructions_.push_back(
std::make_unique<HloComputationFusion>(computation, this));
}

absl::InlinedVector<HloInstructionAdaptor, 2>
Expand Down
88 changes: 36 additions & 52 deletions third_party/xla/xla/service/gpu/hlo_traversal.h
Expand Up @@ -16,15 +16,12 @@ limitations under the License.
#define XLA_SERVICE_GPU_HLO_TRAVERSAL_H_

#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -34,12 +31,15 @@ limitations under the License.
namespace xla {
namespace gpu {

class HloFusionAdaptor;

// Treats HloInstructions as if they were unfused.
class HloInstructionAdaptor {
public:
HloInstructionAdaptor() = default;
explicit HloInstructionAdaptor(const HloInstruction& instruction)
: instruction_(&instruction) {}
explicit HloInstructionAdaptor(const HloInstruction& instruction,
const HloFusionAdaptor* parent = nullptr)
: instruction_(&instruction), parent_(parent) {}

HloOpcode opcode() const { return instruction_->opcode(); }
absl::string_view name() const { return instruction_->name(); }
Expand All @@ -60,6 +60,13 @@ class HloInstructionAdaptor {

private:
const HloInstruction* instruction_;

// Pointer to the parent fusion adaptor. Can be null for legacy cases when
// HloInstructionAdaptor is used without HloFusionAdaptor.
// TODO(shyshkov): Consistently set parent pointer in all cases and check that
// it is not null.
// TODO(shyshkov): Use parent to determine operands and users correctly.
const HloFusionAdaptor* parent_;
};

template <typename H>
Expand All @@ -73,66 +80,43 @@ bool IsOpcodeAnyOf(const HloInstructionAdaptor& adaptor) {
return (adaptor.opcode() == op) || ((adaptor.opcode() == rest) || ...);
}

class HloFusionAdaptor {
namespace internal {

// An interface to abstract away the difference between single instruction
// fusion and fused computations.
class HloFusionInstructionAdaptor {
public:
virtual ~HloFusionAdaptor() = default;
virtual ~HloFusionInstructionAdaptor() = default;
virtual bool ContainsInstruction(HloInstructionAdaptor instruction) const = 0;
virtual absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const = 0;
virtual absl::InlinedVector<HloInstructionAdaptor, 2>
MakeInstructionPostOrder() const = 0;
virtual std::string ToString() const = 0;

static std::unique_ptr<HloFusionAdaptor> ForInstruction(
const HloInstruction* instruction);
static std::unique_ptr<HloFusionAdaptor> ForComputation(
const HloComputation* computation);
};

class ProducerConsumerFusion : public HloFusionAdaptor {
public:
ProducerConsumerFusion(std::unique_ptr<HloFusionAdaptor> producer,
std::unique_ptr<HloFusionAdaptor> consumer)
: producer_(std::move(producer)), consumer_(std::move(consumer)) {}

ProducerConsumerFusion(const HloInstruction* producer,
const HloInstruction* consumer)
: ProducerConsumerFusion(HloFusionAdaptor::ForInstruction(producer),
HloFusionAdaptor::ForInstruction(consumer)) {}

bool ContainsInstruction(HloInstructionAdaptor instruction) const override {
return producer_->ContainsInstruction(instruction) ||
consumer_->ContainsInstruction(instruction);
}

absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const override {
return consumer_->GetRoots();
}
} // namespace internal

class HloFusionAdaptor {
public:
bool ContainsInstruction(HloInstructionAdaptor instruction) const;
absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const;
absl::InlinedVector<HloInstructionAdaptor, 2> MakeInstructionPostOrder()
const override {
auto producer_post_order = producer_->MakeInstructionPostOrder();
auto consumer_post_order = consumer_->MakeInstructionPostOrder();

producer_post_order.reserve(consumer_post_order.size() +
producer_post_order.size());

absl::c_move(consumer_post_order, std::back_inserter(producer_post_order));

return producer_post_order;
}
const;
std::string ToString() const;

std::string ToString() const override {
// TODO: Add a parameter to indent output on nested adaptor for better
// visual representation. Nested producer-consumers fusion are not used in
// practice yet.
return absl::StrJoin({std::string("producer-consumer fusion:"),
producer_->ToString(), consumer_->ToString()},
"\n");
}
static std::unique_ptr<HloFusionAdaptor> ForInstruction(
const HloInstruction* instruction);
static std::unique_ptr<HloFusionAdaptor> ForProducerConsumer(
const HloInstruction* producer, const HloInstruction* consumer);
static std::unique_ptr<HloFusionAdaptor> ForComputation(
const HloComputation* computation);

private:
std::unique_ptr<HloFusionAdaptor> producer_;
std::unique_ptr<HloFusionAdaptor> consumer_;
void AddInstruction(const HloInstruction* instruction);
void AddComputation(const HloComputation* computation);

absl::InlinedVector<std::unique_ptr<internal::HloFusionInstructionAdaptor>, 2>
fusion_instructions_;
};

enum class TraversalResult {
Expand Down

0 comments on commit 49837f3

Please sign in to comment.