Use data layout attributes for index bit widths.
This fixes an issue in our code where we used wider indices than necessary, and also works
around a bug in LLVM's separate-const-offset-from-gep pass, which produces
incorrect results in the presence of truncations. After this change, we use
the same bit width for GEPs as for other indices, so there are never any
truncs.

PiperOrigin-RevId: 625595331
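The core of the change: the emitter attaches the chosen index bit width to the MLIR module as a DLTI data layout entry, and downstream passes (such as the tensor lowering) read it back through mlir::DataLayout instead of hard-coding i32 or i64. Below is a minimal standalone sketch of that pattern, mirroring the hunks in this commit; the helper names SetIndexBitwidth and GetIndexType are illustrative, not part of the commit.

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

// Attach the index bit width to the module as a data layout spec entry.
void SetIndexBitwidth(mlir::ModuleOp module, int bitwidth) {
  mlir::OpBuilder b(module.getContext());
  auto entry = mlir::DataLayoutEntryAttr::get(
      b.getIndexType(), b.getI32IntegerAttr(bitwidth));
  module->setAttr(
      mlir::DLTIDialect::kDataLayoutAttrName,
      mlir::DataLayoutSpecAttr::get(module.getContext(), {entry}));
}

// Any pass can later recover the width from the closest enclosing layout
// spec and build indices of exactly that type, so no sext/trunc pairs appear.
mlir::IntegerType GetIndexType(mlir::OpBuilder& b, mlir::Operation* op) {
  return b.getIntegerType(
      mlir::DataLayout::closest(op).getTypeSizeInBits(b.getIndexType()));
}

With a 32-bit entry, the module carries the attribute dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>, which is exactly what the updated lower_tensors test below expects.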
jreiffers authored and tensorflower-gardener committed Apr 17, 2024
1 parent 6e7340c commit 0c3ce98
Showing 8 changed files with 56 additions and 25 deletions.
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/mlir/BUILD
@@ -159,6 +159,7 @@ cc_library(
":type_util",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
@@ -189,11 +190,11 @@ cc_library(
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationInterfaces",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:ComplexToStandard",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
@@ -137,10 +138,9 @@ mlir::LLVM::GEPOp CreateGep(mlir::Operation* op,
rewriter.setInsertionPoint(op);
Value index = rewriter.create<mlir::affine::AffineApplyOp>(
tensor.getLoc(), linearize_map, indices);
auto index_ty =
ShapeUtil::ElementsIn(byte_shape) < std::numeric_limits<int32_t>::max()
? rewriter.getI32Type()
: rewriter.getI64Type();
auto index_ty = rewriter.getIntegerType(
mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp())
.getTypeSizeInBits(index.getType()));
index = rewriter.create<mlir::arith::IndexCastUIOp>(tensor.getLoc(), index_ty,
index);

@@ -52,7 +52,10 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {

void runOnOperation() override {
// Populate type conversions.
mlir::LLVMTypeConverter type_converter(getOperation().getContext());
mlir::LowerToLLVMOptions llvm_opts(&getContext(),
mlir::DataLayout(getOperation()));
mlir::LLVMTypeConverter type_converter(getOperation().getContext(),
llvm_opts);
mlir::LLVMConversionTarget target(*getOperation().getContext());

// Populate patterns.
@@ -44,6 +44,7 @@ limitations under the License.
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project
#include "mlir/Dialect/DLTI/DLTI.h" // from @llvm-project
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
@@ -92,6 +93,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

@@ -141,6 +143,21 @@ void AddRanges(llvm::Function* func, const LaunchDimensions& launch_dims,
}
}

bool Needs64Bits(const Shape& shape) {
return shape.IsArray() ? !IsInt32(ShapeUtil::ElementsIn(shape))
: absl::c_any_of(shape.tuple_shapes(), Needs64Bits);
}

bool Needs64BitIndices(const HloComputation* computation) {
for (auto* instr : computation->instructions()) {
if (Needs64Bits(instr->shape()) ||
absl::c_any_of(instr->called_computations(), Needs64BitIndices)) {
return true;
}
}
return false;
}

} // namespace

Value MlirFusionEmitterBase::EmitBlockId(mlir::ImplicitLocOpBuilder& builder,
@@ -302,12 +319,12 @@ MlirFusionEmitterBase::CreateMLIRModule(
mlir::MLIRContext& context, const HloFusionInstruction& fusion,
const std::string& entry_function_name,
const BufferAssignment* buffer_assignment) const {
context.loadDialect<mlir::tensor::TensorDialect, mlir::func::FuncDialect,
mlir::affine::AffineDialect, mlir::arith::ArithDialect,
mlir::cf::ControlFlowDialect, mlir::math::MathDialect,
mlir::scf::SCFDialect, mlir::mhlo::MhloDialect,
mlir::gpu::GPUDialect, mlir::NVVM::NVVMDialect,
xla::gpu::XlaGpuDialect>();
context.loadDialect<mlir::DLTIDialect, mlir::tensor::TensorDialect,
mlir::func::FuncDialect, mlir::affine::AffineDialect,
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::math::MathDialect, mlir::scf::SCFDialect,
mlir::mhlo::MhloDialect, mlir::gpu::GPUDialect,
mlir::NVVM::NVVMDialect, xla::gpu::XlaGpuDialect>();
mlir::DialectRegistry registry;
mlir::func::registerInlinerExtension(registry);
mlir::registerBuiltinDialectTranslation(registry);
@@ -451,6 +468,15 @@ absl::Status MlirFusionEmitterBase::EmitMlir(
*epilogue, subgraph_to_mlir_fn[&*epilogue], call_targets));
}

int index_bitwidth =
Needs64BitIndices(fusion.fused_instructions_computation()) ? 64 : 32;
mlir::OpBuilder b(module->getContext());
auto index_layout = mlir::DataLayoutEntryAttr::get(
b.getIndexType(), b.getI32IntegerAttr(index_bitwidth));
module->setAttr(
mlir::DLTIDialect::kDataLayoutAttrName,
mlir::DataLayoutSpecAttr::get(module->getContext(), {index_layout}));

return EmitEntryFunction(computations, call_targets, entry_function, fusion);
}

@@ -170,11 +170,9 @@ TEST_F(MlirFusionEmitterTest, CreateLLVMModule) {
TF_ASSERT_OK_AND_ASSIGN(auto filecheck_result, RunFileCheck(out, R"(
// CHECK: define void @fusion(ptr noalias %[[IN:.*]], ptr noalias %[[OUT:.*]])
// CHECK: %[[TID:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
// CHECK: %[[EXT:.*]] = sext i32 %[[TID]] to i64
// CHECK: %[[TRUNC:.*]] = trunc i64 %[[EXT]] to i32
// CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TRUNC]]
// CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TID]]
// CHECK: %[[VAL:.*]] = load float, ptr %[[IN_PTR]], align 4
// CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TRUNC]]
// CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TID]]
// CHECK: store float %[[VAL]], ptr %[[OUT_PTR]], align 4
// CHECK: ret void
)"));
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD
@@ -16,6 +16,7 @@ xla_cc_binary(
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:DLTIDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:GPUDialect",
@@ -1,6 +1,6 @@
// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-lower-tensors | FileCheck %s

module {
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
func.func private @add(%arg0: f32, %arg1: f32) -> f32 {
%sum = arith.addf %arg0, %arg1 : f32
func.return %sum : f32
@@ -72,7 +72,7 @@ module {
// CHECK: @layout(%[[ARG0:.*]]: !llvm.ptr,
// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]](%[[X]], %[[Y]])
// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i32
// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]]
// CHECK: llvm.load %[[PTR]]

@@ -110,7 +110,7 @@ module {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i32
// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i64
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]]
// CHECK: llvm.store {{.*}}, %[[PTR]]
// CHECK: %[[INBOUNDS:.*]] = arith.cmpi
@@ -16,6 +16,7 @@ limitations under the License.
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
#include "mlir/Dialect/DLTI/DLTI.h" // from @llvm-project
#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
@@ -32,12 +33,13 @@ limitations under the License.

int main(int argc, char **argv) {
mlir::DialectRegistry registry;
registry.insert<mlir::tensor::TensorDialect, mlir::func::FuncDialect,
mlir::affine::AffineDialect, mlir::arith::ArithDialect,
mlir::complex::ComplexDialect, mlir::math::MathDialect,
mlir::scf::SCFDialect, mlir::mhlo::MhloDialect,
mlir::LLVM::LLVMDialect, mlir::gpu::GPUDialect,
mlir::mhlo::MhloDialect, xla::gpu::XlaGpuDialect>();
registry.insert<mlir::DLTIDialect, mlir::tensor::TensorDialect,
mlir::func::FuncDialect, mlir::affine::AffineDialect,
mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::math::MathDialect, mlir::scf::SCFDialect,
mlir::mhlo::MhloDialect, mlir::LLVM::LLVMDialect,
mlir::gpu::GPUDialect, mlir::mhlo::MhloDialect,
xla::gpu::XlaGpuDialect>();
mlir::func::registerAllExtensions(registry);
mlir::registerCanonicalizerPass();
mlir::registerCSEPass();