Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Add TiledHloInstruction. #65006

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,39 @@ xla_cc_test(
],
)

# Library providing TiledHloInstruction, a wrapper around HloInstruction that
# carries tile sizes, tile strides and a block-id -> tile-offsets indexing map.
cc_library(
name = "tiled_hlo_instruction",
srcs = ["tiled_hlo_instruction.cc"],
hdrs = ["tiled_hlo_instruction.h"],
deps = [
":indexing_map",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

# Unit tests for the :tiled_hlo_instruction library.
xla_cc_test(
name = "tiled_hlo_instruction_test",
srcs = ["tiled_hlo_instruction_test.cc"],
deps = [
":indexing_map",
":indexing_test_utils",
":tiled_hlo_instruction",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:statusor",
],
)

cc_library(
name = "symbolic_tile_analysis",
srcs = ["symbolic_tile_analysis.cc"],
Expand Down
118 changes: 118 additions & 0 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/model/tiled_hlo_instruction.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/util.h"

namespace xla {
namespace gpu {

// Hashes the instruction behind `tiled_hlo` by value (via AbslHashValue), not
// by address. The pointer is dereferenced unconditionally.
size_t TiledHloInstruction::PtrHash::operator()(
    const TiledHloInstruction* tiled_hlo) const {
  const TiledHloInstruction& pointee = *tiled_hlo;
  return absl::HashOf(pointee);
}

// Compares the instructions behind the two pointers by value using
// operator==. Both pointers are dereferenced unconditionally.
bool TiledHloInstruction::PtrEqual::operator()(
    const TiledHloInstruction* lhs, const TiledHloInstruction* rhs) const {
  const TiledHloInstruction& left = *lhs;
  const TiledHloInstruction& right = *rhs;
  return left == right;
}

// Value equality: same wrapped HloInstruction pointer, same tile sizes, same
// tile strides and same block-id -> tile-offsets indexing map. Checks run in
// that order and stop at the first mismatch.
bool operator==(const TiledHloInstruction& lhs,
                const TiledHloInstruction& rhs) {
  if (lhs.hlo() != rhs.hlo()) {
    return false;
  }
  if (lhs.tile_sizes() != rhs.tile_sizes()) {
    return false;
  }
  if (lhs.tile_strides() != rhs.tile_strides()) {
    return false;
  }
  return lhs.block_id_to_tile_offsets_indexing() ==
         rhs.block_id_to_tile_offsets_indexing();
}

// Logical negation of operator== for TiledHloInstruction.
bool operator!=(const TiledHloInstruction& lhs,
                const TiledHloInstruction& rhs) {
  const bool are_equal = (lhs == rhs);
  return !are_equal;
}

// Validates the arguments against the rank of `hlo`'s shape and, if they are
// consistent, builds a TiledHloInstruction that takes ownership of the vectors
// and the indexing map.
//
// Returns InvalidArgument when:
//   * tile_sizes.size() differs from the rank of the hlo shape;
//   * tile_strides.size() differs from the rank of the hlo shape;
//   * the indexing map does not have exactly 1 dimension and 0 symbols;
//   * the indexing map's number of results differs from the rank.
//
// `hlo` is dereferenced unconditionally.
/*static*/
absl::StatusOr<std::unique_ptr<TiledHloInstruction>>
TiledHloInstruction::Create(const HloInstruction* hlo,
                            std::vector<int64_t> tile_sizes,
                            std::vector<int64_t> tile_strides,
                            IndexingMap block_id_to_tile_offsets_indexing) {
  // Keep the rank unsigned so the comparisons against container sizes and the
  // affine map result count below do not mix signed and unsigned operands.
  const size_t rank = static_cast<size_t>(hlo->shape().rank());

  if (tile_sizes.size() != rank) {
    return absl::InvalidArgumentError(
        absl::StrCat("Number of tile sizes must be equal to the rank of the "
                     "hlo shape. tile_sizes = ",
                     tile_sizes.size(), ", hlo = ", hlo->ToString()));
  }

  if (tile_strides.size() != rank) {
    // Report the stride count under its own label (the message previously
    // reused the "tile_sizes" label from the check above).
    return absl::InvalidArgumentError(
        absl::StrCat("Number of tile strides must be equal to the rank of the "
                     "hlo shape. tile_strides = ",
                     tile_strides.size(), ", hlo = ", hlo->ToString()));
  }

  if (block_id_to_tile_offsets_indexing.GetDimensionCount() != 1 ||
      block_id_to_tile_offsets_indexing.GetSymbolCount() != 0) {
    return absl::InvalidArgumentError(absl::StrCat(
        "block_id_to_tile_offsets_indexing must have 1 dim and 0 symbols. "
        "block_id_to_tile_offsets_indexing = ",
        block_id_to_tile_offsets_indexing.ToString()));
  }

  if (block_id_to_tile_offsets_indexing.GetAffineMap().getNumResults() !=
      rank) {
    return absl::InvalidArgumentError(absl::StrCat(
        "block_id_to_tile_offsets_indexing must have the same number of "
        "results as the rank of the hlo shape. "
        "block_id_to_tile_offsets_indexing = ",
        block_id_to_tile_offsets_indexing.ToString(),
        ", hlo = ", hlo->ToString()));
  }

  // The constructor is private, so std::make_unique cannot be used here.
  return absl::WrapUnique(new TiledHloInstruction(
      hlo, std::move(tile_sizes), std::move(tile_strides),
      std::move(block_id_to_tile_offsets_indexing)));
}

// Renders the instruction as multi-line text: the wrapped HLO, the tile sizes,
// the tile strides and finally the block-id -> tile-offsets indexing map.
// Operands are not printed.
std::string TiledHloInstruction::ToString() const {
  std::stringstream out;
  out << "hlo: " << hlo_->ToString() << "\n"
      << "tile_sizes: {" << absl::StrJoin(tile_sizes_, ", ") << "}\n"
      << "tile_strides: {" << absl::StrJoin(tile_strides_, ", ") << "}\n"
      << "block_id_to_tile_offsets_indexing: "
      << block_id_to_tile_offsets_indexing_;
  return out.str();
}

} // namespace gpu
} // namespace xla
136 changes: 136 additions & 0 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_
#define XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"

namespace xla {
namespace gpu {

// A wrapper around HloInstruction that represents a tiled HLO instruction.
//
// The class contains information required to emit this instruction in
// block-level codegen. Tile sizes and strides are constants and do not depend
// on the block id. Tile offsets are computed using an indexing map of form:
// `(block_id) -> (tile_offset0, tile_offset1, ...)`.
//
// Instances are created through `Create()`; the constructor is private.
class TiledHloInstruction {
public:
// PtrHash and PtrEqual are helper classes to use in hash maps and sets that
// compare values behind the pointers. For example,
// absl::flat_hash_set<TiledHloInstruction*, TiledHloInstruction::PtrHash,
//                     TiledHloInstruction::PtrEqual> hlo_set;
// Both functors dereference their arguments, so the pointers must be
// non-null.
struct PtrHash {
size_t operator()(const TiledHloInstruction* tiled_hlo) const;
};

struct PtrEqual {
bool operator()(const TiledHloInstruction* lhs,
const TiledHloInstruction* rhs) const;
};

// Creates an instance of TiledHloInstruction. Returns an error if any of the
// following preconditions is not met:
// * Number of tile sizes, strides should match HLO shape rank.
// * Number of results of `block_id_to_tile_offsets_indexing` should match HLO
// shape rank.
// * `block_id_to_tile_offsets_indexing` should have only 1 dimension and 0
// symbols.
static absl::StatusOr<std::unique_ptr<TiledHloInstruction>> Create(
const HloInstruction* hlo, std::vector<int64_t> tile_sizes,
std::vector<int64_t> tile_strides,
IndexingMap block_id_to_tile_offsets_indexing);

// Returns the original HLO instruction.
const HloInstruction* hlo() const { return hlo_; }

// Returns the tile sizes. The number of tile sizes is equal to the rank of
// the output shape.
const std::vector<int64_t>& tile_sizes() const { return tile_sizes_; }

// Returns the tile strides. The number of tile strides is equal to the rank
// of the output shape.
const std::vector<int64_t>& tile_strides() const { return tile_strides_; }

// Returns the indexing map from block_id to tile offsets. The map has a form
// of `(block_id) -> (tile_offset0, tile_offset1, ...)`. The number of tile
// offsets is equal to the rank of the output shape.
const IndexingMap& block_id_to_tile_offsets_indexing() const {
return block_id_to_tile_offsets_indexing_;
}

// Returns the operand at position `operand_id`, in the order the operands
// were appended. The index is not bounds-checked.
const TiledHloInstruction* operand(int64_t operand_id) const {
return operands_[operand_id];
}

// Returns all operands in the order they were appended.
const std::vector<TiledHloInstruction*>& operands() const {
return operands_;
}

// Appends `operand` to the operand list. Only the raw pointer is stored and
// this class never deletes it, so the pointee has to stay alive for as long
// as it is accessed through `operand()` / `operands()`.
void AppendOperand(TiledHloInstruction* operand) {
operands_.push_back(operand);
}

// Returns a multi-line, human-readable description containing the HLO, the
// tile sizes, the tile strides and the indexing map.
std::string ToString() const;

private:
// Stores the arguments without validation; `Create()` performs the checks
// before calling this constructor.
TiledHloInstruction(const HloInstruction* hlo,
std::vector<int64_t> tile_sizes,
std::vector<int64_t> tile_strides,
IndexingMap block_id_to_tile_offsets_indexing)
: hlo_(hlo),
tile_sizes_(std::move(tile_sizes)),
tile_strides_(std::move(tile_strides)),
block_id_to_tile_offsets_indexing_(
std::move(block_id_to_tile_offsets_indexing)) {}

// Pointer to the original HLO instruction.
const HloInstruction* hlo_;

// Tile sizes and strides.
std::vector<int64_t> tile_sizes_;
std::vector<int64_t> tile_strides_;

// Indexing map from block_id to tile offsets.
IndexingMap block_id_to_tile_offsets_indexing_;

// Operands of the instruction in the tiled computation graph.
std::vector<TiledHloInstruction*> operands_;
};

// Two TiledHloInstructions compare equal when they wrap the same
// HloInstruction pointer and have equal tile sizes, tile strides and
// block_id_to_tile_offsets_indexing. Operands are not part of the comparison.
bool operator==(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs);
// Negation of operator==.
bool operator!=(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs);

template <typename H>
H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) {
return H::combine(std::move(h), tiled_hlo_instruction.hlo(),
tiled_hlo_instruction.tile_sizes(),
tiled_hlo_instruction.tile_strides(),
tiled_hlo_instruction.block_id_to_tile_offsets_indexing());
}

} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_