Skip to content

Commit

Permalink
[XLA:GPU][MLIR-based emitters] Add verifier for apply_indexing op.
Browse files Browse the repository at this point in the history
Reverts 2ba594d

PiperOrigin-RevId: 627542099
  • Loading branch information
pifon2a authored and tensorflower-gardener committed Apr 24, 2024
1 parent b251941 commit d1b2564
Show file tree
Hide file tree
Showing 18 changed files with 98 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

namespace tflite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

namespace tflite {
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/delegates/gpu/cl/serialization_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

#include "gpu_model_generated.h"
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/delegates/gpu/common/gpu_model_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

#include "serialization_base_generated.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

namespace tflite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

namespace tflite {
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/schema/conversion_metadata_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

namespace tflite {
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/schema/schema_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License.

// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
FLATBUFFERS_VERSION_MINOR == 3 &&
FLATBUFFERS_VERSION_REVISION == 7,
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
FLATBUFFERS_VERSION_MINOR == 5 &&
FLATBUFFERS_VERSION_REVISION == 26,
"Non-compatible flatbuffers version included");

namespace tflite {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/tools/cmake/modules/flatbuffers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ OverridableFetchContent_Declare(
flatbuffers
GIT_REPOSITORY https://github.com/google/flatbuffers
# Sync with tensorflow/third_party/flatbuffers/workspace.bzl
GIT_TAG v24.3.7
GIT_TAG v23.5.26
GIT_SHALLOW TRUE
GIT_PROGRESS TRUE
SOURCE_DIR "${CMAKE_BINARY_DIR}/flatbuffers"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/ci_build/release/requirements_common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This will change in the future.
absl-py ~= 1.0.0
astunparse ~= 1.6.3
flatbuffers ~= 24.3.7
flatbuffers ~= 23.5.26
google_pasta ~= 0.2
h5py ~= 3.10.0 # Earliest version for Python 3.12
ml_dtypes ~= 0.3.1
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/pip_package/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def standard_or_nightly(standard, nightly):
REQUIRED_PACKAGES = [
'absl-py >= 1.0.0',
'astunparse >= 1.6.0',
'flatbuffers >= 24.3.7',
'flatbuffers >= 23.5.26',
'gast >=0.2.1,!=0.5.0,!=0.5.1,!=0.5.2',
'google_pasta >= 0.1.1',
'h5py >= 3.10.0',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# This will change in the future.
absl-py ~= 1.0.0
astunparse ~= 1.6.3
flatbuffers ~= 24.3.7
flatbuffers ~= 23.5.26
google_pasta ~= 0.2
h5py ~= 3.10.0 # Earliest version for Python 3.12
ml_dtypes ~= 0.3.1
Expand Down
4 changes: 2 additions & 2 deletions third_party/flatbuffers/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# _FLATBUFFERS_GIT_COMMIT / _FLATBUFFERS_SHA256 were added due to an urgent change being made to
# Flatbuffers that needed to be updated in order for Flatbuffers/TfLite be compatible with Android
# API level >= 23. They can be removed next flatbuffers offical release / update.
_FLATBUFFERS_GIT_COMMIT = "6ff9e90e7e399f3977e99a315856b57c8afe5b4d"
_FLATBUFFERS_GIT_COMMIT = "7d6d99c6befa635780a4e944d37ebfd58e68a108"

# curl -L https://github.com/google/flatbuffers/archive/<_FLATBUFFERS_GIT_COMMIT>.tar.gz | shasum -a 256
_FLATBUFFERS_SHA256 = "f4b3dfed9f8f4f0fd9f857fe96a46199cb5745ddb458cad20caf6837230ea188"
_FLATBUFFERS_SHA256 = "d27761f6b2fb1017ec00ed317a7b98cb7aed86b81d90528b498fb17ec13579a1"

def repo():
tf_http_archive(
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cc_library(
hdrs = ["xla_gpu_ops.h"],
deps = [
":xla_gpu_ops_inc_gen",
"//xla/service/gpu/model:indexing_map",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:BytecodeOpInterface",
"@llvm-project//mlir:CallOpInterfaces",
Expand Down
44 changes: 43 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"

#include <cstdint>
#include <utility>
#include <vector>

#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
Expand All @@ -32,11 +34,13 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc"
#include "xla/service/gpu/model/indexing_map.h"

namespace xla {
namespace gpu {
namespace {

using mlir::AffineMap;
using mlir::failure;
using mlir::LogicalResult;
using mlir::OpBuilder;
Expand Down Expand Up @@ -209,7 +213,7 @@ mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser &parser,

void ApplyIndexingOp::print(mlir::OpAsmPrinter &p) {
mlir::AffineMapAttr affine_map_attr = getMapAttr();
mlir::AffineMap affine_map = affine_map_attr.getAffineMap();
AffineMap affine_map = affine_map_attr.getAffineMap();
p << " " << affine_map_attr;

auto lower_bounds = getLowerBounds();
Expand Down Expand Up @@ -244,6 +248,44 @@ void ApplyIndexingOp::print(mlir::OpAsmPrinter &p) {
"map", "lower_bounds", "upper_bounds"});
}

LogicalResult ApplyIndexingOp::verify() {
auto affine_map = getMapAttr().getAffineMap();
unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols();
if (getOperands().size() != num_variables ||
getLowerBounds().size() != num_variables ||
getUpperBounds().size() != num_variables) {
return emitOpError(
"operand, lower_bounds, upper_bounds count and affine map dimension "
"and symbol count must match");
}
IndexingMap indexing_map = getIndexingMap();
if (indexing_map.IsKnownEmpty()) {
return emitOpError("indexing map is empty");
}
return success();
}

IndexingMap ApplyIndexingOp::getIndexingMap() {
auto lower_bounds = getLowerBounds();
auto upper_bounds = getUpperBounds();

AffineMap affine_map = getMapAttr().getAffineMap();
unsigned num_dimensions = affine_map.getNumDims();
std::vector<DimVar> dim_vars;
dim_vars.reserve(num_dimensions);
for (int id = 0; id < num_dimensions; ++id) {
dim_vars.emplace_back(Interval{lower_bounds[id], upper_bounds[id]});
}
unsigned num_symbols = affine_map.getNumSymbols();
std::vector<RangeVar> range_vars;
range_vars.reserve(num_symbols);
for (int id = 0; id < num_symbols; ++id) {
range_vars.emplace_back(Interval{lower_bounds[id], upper_bounds[id]});
}
return IndexingMap(affine_map, std::move(dim_vars), std::move(range_vars),
/*rt_vars=*/{});
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project // IWYU pragma: keep
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project // IWYU pragma : keep
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project // IWYU pragma : keep
#include "xla/service/gpu/model/indexing_map.h"

#define GET_OP_CLASSES
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,13 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> {
DenseI64ArrayAttr:$upper_bounds);
let results = (outs Variadic<Index>);


let extraClassDeclaration = [{
// Return an indexing map constructed from affine_map and the bounds.
xla::gpu::IndexingMap getIndexingMap();
}];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS
17 changes: 17 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics

#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)>
func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) {
// expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}}
%0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2])
func.return %0#0, %0#1 : index, index
}

// -----

#map0 = affine_map<(d0) -> (d0)>
func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> index {
// expected-error @+1 {{indexing map is empty}}
%0 = xla_gpu.apply_indexing #map0 (%d0 in [100, 0])
func.return %0 : index
}

0 comments on commit d1b2564

Please sign in to comment.