Skip to content

Commit

Permalink
[XLA:GPU] Add TiledHloInstruction.
Browse files Browse the repository at this point in the history
A graph of TiledHloInstruction represents an HLO graph with associated concrete tile sizes. In the following changes, I'll add code to build the graph from SymbolicTiledHloInstruction and use the tiled graph for the Cost Model and Triton codegen.

PiperOrigin-RevId: 621903701
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Apr 4, 2024
1 parent 9f5843e commit bf8b220
Show file tree
Hide file tree
Showing 4 changed files with 429 additions and 0 deletions.
33 changes: 33 additions & 0 deletions third_party/xla/xla/service/gpu/model/BUILD
Expand Up @@ -590,6 +590,39 @@ xla_cc_test(
],
)

# Library target for the TiledHloInstruction class (tiled_hlo_instruction.h/.cc).
# Depends on the sibling `:indexing_map` target, the XLA HLO IR and Abseil.
cc_library(
name = "tiled_hlo_instruction",
srcs = ["tiled_hlo_instruction.cc"],
hdrs = ["tiled_hlo_instruction.h"],
deps = [
":indexing_map",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

# Test target for `:tiled_hlo_instruction`, built from
# tiled_hlo_instruction_test.cc and linked against gtest_main.
xla_cc_test(
name = "tiled_hlo_instruction_test",
srcs = ["tiled_hlo_instruction_test.cc"],
deps = [
":indexing_map",
":indexing_test_utils",
":tiled_hlo_instruction",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:statusor",
],
)

cc_library(
name = "symbolic_tile_analysis",
srcs = ["symbolic_tile_analysis.cc"],
Expand Down
118 changes: 118 additions & 0 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc
@@ -0,0 +1,118 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/model/tiled_hlo_instruction.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/util.h"

namespace xla {
namespace gpu {

// Hashes the TiledHloInstruction behind `tiled_hlo` (the pointee's value, not
// the pointer itself). The pointer is dereferenced unconditionally, so it must
// not be null.
size_t TiledHloInstruction::PtrHash::operator()(
    const TiledHloInstruction* tiled_hlo) const {
  const TiledHloInstruction& pointee = *tiled_hlo;
  return absl::HashOf(pointee);
}

// Compares the TiledHloInstruction objects behind the two pointers by value
// using operator==. Both pointers are dereferenced unconditionally, so neither
// may be null.
bool TiledHloInstruction::PtrEqual::operator()(
    const TiledHloInstruction* lhs, const TiledHloInstruction* rhs) const {
  const TiledHloInstruction& left = *lhs;
  const TiledHloInstruction& right = *rhs;
  return left == right;
}

// Two tiled instructions are equal when they wrap the same HloInstruction
// pointer and have equal tile sizes, tile strides and block-id-to-tile-offsets
// indexing maps. Operands are not part of the comparison.
bool operator==(const TiledHloInstruction& lhs,
                const TiledHloInstruction& rhs) {
  if (lhs.hlo() != rhs.hlo()) {
    return false;
  }
  if (lhs.tile_sizes() != rhs.tile_sizes()) {
    return false;
  }
  if (lhs.tile_strides() != rhs.tile_strides()) {
    return false;
  }
  return lhs.block_id_to_tile_offsets_indexing() ==
         rhs.block_id_to_tile_offsets_indexing();
}

// Negation of operator== for TiledHloInstruction.
bool operator!=(const TiledHloInstruction& lhs,
                const TiledHloInstruction& rhs) {
  const bool are_equal = (lhs == rhs);
  return !are_equal;
}

// Validates the arguments and, if they are consistent, builds a
// TiledHloInstruction that takes ownership of the moved-in vectors and map.
//
// Returns absl::InvalidArgumentError when:
//   * the number of tile sizes differs from the rank of `hlo`'s shape;
//   * the number of tile strides differs from the rank of `hlo`'s shape;
//   * `block_id_to_tile_offsets_indexing` does not have exactly 1 dimension
//     and 0 symbols;
//   * the number of results of the indexing map's affine map differs from the
//     rank of `hlo`'s shape.
//
// `hlo` is dereferenced unconditionally and must not be null.
/*static*/
absl::StatusOr<std::unique_ptr<TiledHloInstruction>>
TiledHloInstruction::Create(const HloInstruction* hlo,
                            std::vector<int64_t> tile_sizes,
                            std::vector<int64_t> tile_strides,
                            IndexingMap block_id_to_tile_offsets_indexing) {
  // Held as an unsigned value so the comparisons against container sizes and
  // the affine-map result count below do not mix signed and unsigned operands.
  const size_t rank = static_cast<size_t>(hlo->shape().rank());

  if (tile_sizes.size() != rank) {
    return absl::InvalidArgumentError(
        absl::StrCat("Number of tile sizes must be equal to the rank of the "
                     "hlo shape. tile_sizes = ",
                     tile_sizes.size(), ", hlo = ", hlo->ToString()));
  }

  if (tile_strides.size() != rank) {
    // The label names the strides: the reported value is tile_strides.size().
    return absl::InvalidArgumentError(
        absl::StrCat("Number of tile strides must be equal to the rank of the "
                     "hlo shape. tile_strides = ",
                     tile_strides.size(), ", hlo = ", hlo->ToString()));
  }

  if (block_id_to_tile_offsets_indexing.GetDimensionCount() != 1 ||
      block_id_to_tile_offsets_indexing.GetSymbolCount() != 0) {
    return absl::InvalidArgumentError(absl::StrCat(
        "block_id_to_tile_offsets_indexing must have 1 dim and 0 symbols. "
        "block_id_to_tile_offsets_indexing = ",
        block_id_to_tile_offsets_indexing.ToString()));
  }

  if (block_id_to_tile_offsets_indexing.GetAffineMap().getNumResults() !=
      rank) {
    return absl::InvalidArgumentError(absl::StrCat(
        "block_id_to_tile_offsets_indexing must have the same number of "
        "results as the rank of the hlo shape. "
        "block_id_to_tile_offsets_indexing = ",
        block_id_to_tile_offsets_indexing.ToString(),
        ", hlo = ", hlo->ToString()));
  }

  // The constructor is private, so std::make_unique cannot be used here;
  // WrapUnique takes ownership of the freshly allocated object immediately.
  return absl::WrapUnique(new TiledHloInstruction(
      hlo, std::move(tile_sizes), std::move(tile_strides),
      std::move(block_id_to_tile_offsets_indexing)));
}

// Renders the instruction as multi-line text: the wrapped HLO, the tile sizes,
// the tile strides and finally the block-id-to-tile-offsets indexing map
// (streamed via its operator<<). Operands are not printed.
std::string TiledHloInstruction::ToString() const {
  std::ostringstream out;
  out << "hlo: " << hlo_->ToString() << "\n"
      << "tile_sizes: {" << absl::StrJoin(tile_sizes_, ", ") << "}\n"
      << "tile_strides: {" << absl::StrJoin(tile_strides_, ", ") << "}\n"
      << "block_id_to_tile_offsets_indexing: "
      << block_id_to_tile_offsets_indexing_;
  return out.str();
}

} // namespace gpu
} // namespace xla
136 changes: 136 additions & 0 deletions third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h
@@ -0,0 +1,136 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_
#define XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"

namespace xla {
namespace gpu {

// An HloInstruction paired with the concrete tiling used to emit it.
//
// Holds what block-level codegen needs for one instruction: constant tile
// sizes and tile strides (independent of the block id), plus an indexing map
// of the form `(block_id) -> (tile_offset0, tile_offset1, ...)` from which the
// tile offsets are computed.
class TiledHloInstruction {
 public:
  // Hash functor that hashes the pointed-to value rather than the pointer.
  // Together with PtrEqual it lets containers of pointers deduplicate by
  // value, e.g.
  //   absl::flat_hash_set<TiledHloInstruction*, PtrHash, PtrEqual> hlo_set;
  struct PtrHash {
    size_t operator()(const TiledHloInstruction* tiled_hlo) const;
  };

  // Equality functor that compares the pointed-to values rather than the
  // pointers. See PtrHash.
  struct PtrEqual {
    bool operator()(const TiledHloInstruction* lhs,
                    const TiledHloInstruction* rhs) const;
  };

  // Factory for TiledHloInstruction. Fails with an error status unless all of
  // the following hold:
  //   * The number of tile sizes and the number of tile strides both match
  //     the rank of the HLO shape.
  //   * The number of results of `block_id_to_tile_offsets_indexing` matches
  //     the rank of the HLO shape.
  //   * `block_id_to_tile_offsets_indexing` has exactly 1 dimension and 0
  //     symbols.
  static absl::StatusOr<std::unique_ptr<TiledHloInstruction>> Create(
      const HloInstruction* hlo, std::vector<int64_t> tile_sizes,
      std::vector<int64_t> tile_strides,
      IndexingMap block_id_to_tile_offsets_indexing);

  // The wrapped (original) HLO instruction.
  const HloInstruction* hlo() const { return hlo_; }

  // Tile sizes, one per dimension of the output shape.
  const std::vector<int64_t>& tile_sizes() const { return tile_sizes_; }

  // Tile strides, one per dimension of the output shape.
  const std::vector<int64_t>& tile_strides() const { return tile_strides_; }

  // Indexing map of the form
  // `(block_id) -> (tile_offset0, tile_offset1, ...)`, with one tile offset
  // per dimension of the output shape.
  const IndexingMap& block_id_to_tile_offsets_indexing() const {
    return block_id_to_tile_offsets_indexing_;
  }

  // The operand at position `operand_id`. The index is not bounds-checked.
  const TiledHloInstruction* operand(int64_t operand_id) const {
    return operands_[operand_id];
  }

  // All operands, in the order they were appended.
  const std::vector<TiledHloInstruction*>& operands() const {
    return operands_;
  }

  // Adds `operand` at the end of the operand list. Only the pointer is
  // stored.
  void AppendOperand(TiledHloInstruction* operand) {
    operands_.emplace_back(operand);
  }

  // Human-readable description of this tiled instruction.
  std::string ToString() const;

 private:
  // Private: instances are created through Create().
  TiledHloInstruction(const HloInstruction* hlo,
                      std::vector<int64_t> tile_sizes,
                      std::vector<int64_t> tile_strides,
                      IndexingMap block_id_to_tile_offsets_indexing)
      : hlo_(hlo),
        tile_sizes_(std::move(tile_sizes)),
        tile_strides_(std::move(tile_strides)),
        block_id_to_tile_offsets_indexing_(
            std::move(block_id_to_tile_offsets_indexing)) {}

  // The original HLO instruction this object wraps.
  const HloInstruction* hlo_;

  // Constant tile sizes and strides.
  std::vector<int64_t> tile_sizes_;
  std::vector<int64_t> tile_strides_;

  // Maps a block id to the tile offsets.
  IndexingMap block_id_to_tile_offsets_indexing_;

  // Operands of this instruction in the tiled computation graph.
  std::vector<TiledHloInstruction*> operands_;
};

// Value comparison of two tiled instructions. Equality requires the same
// wrapped HloInstruction pointer and equal tile sizes, tile strides and
// block-id-to-tile-offsets indexing maps; operands are not compared.
// operator!= is the negation of operator==.
bool operator==(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs);
bool operator!=(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs);

template <typename H>
H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) {
return H::combine(std::move(h), tiled_hlo_instruction.hlo(),
tiled_hlo_instruction.tile_sizes(),
tiled_hlo_instruction.tile_strides(),
tiled_hlo_instruction.block_id_to_tile_offsets_indexing());
}

} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_

0 comments on commit bf8b220

Please sign in to comment.