[AMD][gfx11] Fix BF16 wmma instr generation (#4135)
- Pack bf16 elements to int16 vectors;
- Add a lit test;
- BF16 test cases from `test_core.py::test_dot` now pass; the lowered IR shape is sketched below.
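
In LLVM-dialect terms, the fix bitcasts each packed BF16 operand vector to an
`i16` vector before handing it to the ROCDL WMMA intrinsic, whose BF16 A/B
operands are integer-typed on gfx11. A minimal sketch of the lowered IR
(value names are illustrative; types follow the lit test added below):

```mlir
// A/B operands arrive as 16 bf16 elements per thread; repack them as i16 so
// they match the operand types expected by the gfx11 BF16 WMMA intrinsic.
%a_i16 = llvm.bitcast %a_bf16 : vector<16xbf16> to vector<16xi16>
%b_i16 = llvm.bitcast %b_bf16 : vector<16xbf16> to vector<16xi16>
%flag  = llvm.mlir.constant(false) : i1
// The accumulator stays in bf16; the trailing i1 is the intrinsic's opsel-style flag.
%d = rocdl.wmma.bf16.16x16x16.bf16 %a_i16, %b_i16, %c_bf16, %flag :
       (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
```

F16 operands need no such repacking (the F16 intrinsic takes `vector<16xf16>`
directly), which is why only the BF16 path in `WMMA.cpp` gains the extra bitcast.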

The core Triton team is small, and we receive many PRs (thank you!). To help us
review your code more quickly, **if you are a new contributor (fewer than 3 PRs
merged) we ask that you complete the following tasks and include the filled-out
checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]`
to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests

- Select one of the following.
  - [x] The `lit` tests I have added follow these
    [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python code
    and using the instructions it generates is not minimal.)

Signed-off-by: Ilya Veselov <iveselov.nn@gmail.com>
joviliast committed Jun 13, 2024
1 parent 3d5cd67 commit 4a1ea8e
Showing 2 changed files with 18 additions and 2 deletions.
test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir (16 changes: 15 additions & 1 deletion)
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s

#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>
@@ -27,6 +27,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}

+// CHECK-LABEL: wmma_dot_bf16
+tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma>) {
+// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+// CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16>
+// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+// CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16>
+// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+// CHECK: llvm.mlir.undef : vector<16xbf16>
+// CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16>
+// CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xbf16, #mma>
+tt.return
+}
+
// CHECK-LABEL: wmma_dot_int8_32
tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) {
// CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp (4 changes: 3 additions & 1 deletion)
@@ -68,8 +68,10 @@ getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter,
}

Value convertedElems;
-if (type.isBF16() || type.isF16()) {
+if (type.isF16()) {
convertedElems = rawElems;
+} else if (type.isBF16()) {
+convertedElems = bitcast(rawElems, vec_ty(i16_ty, kWidth));
} else {
convertedElems = bitcast(
rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() /
