Skip to content

Commit

Permalink
Don't rematerialize ReduceOp
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Apr 10, 2023
1 parent c404afe commit 19d31c6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
return expensiveLoadOrStore(op, targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
triton::AtomicCASOp, triton::DotOp, triton::ReduceOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
op))
Expand Down
3 changes: 3 additions & 0 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1013,12 +1013,15 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked>
%3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked>
// CHECK-DAG: }) {axis = 1 : i32}
%4 = "tt.reduce" (%cst) ({
^bb0(%arg3: i32, %arg4: i32):
%add = arith.addi %arg3, %arg4 : i32
tt.reduce.return %add : i32
}) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK-NEXT: triton_gpu.convert_layout {{%.*}} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1>
// CHECK-NOT: triton_gpu.convert_layout
%6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
%8 = triton_gpu.convert_layout %7 : (tensor<1x1xi32, #blocked2>) -> tensor<1x1xi32, #blocked>
Expand Down

0 comments on commit 19d31c6

Please sign in to comment.