[xla:gpu] Support capturing lmhlo.reinterpret_cast in cuda graphs
PiperOrigin-RevId: 528626714
anlunx authored and tensorflower-gardener committed May 2, 2023
1 parent 84a9aa9 commit 0666daf
Showing 2 changed files with 41 additions and 0 deletions.
@@ -184,6 +184,7 @@ struct MemcpyOpCapture : public OpCapturePattern {
// Capture pure operations by cloning them into graph capture function.
struct ConstantOpCapture : public CloneOp<arith::ConstantOp> {};
struct ViewOpCapture : public CloneOp<memref::ViewOp> {};
struct ReinterpretCastOpCapture : public CloneOp<memref::ReinterpretCastOp> {};

//===----------------------------------------------------------------------===//

@@ -433,6 +434,7 @@ void OutlineCudaGraphsPass::runOnOperation() {
    patterns.emplace_back(new ConstantOpCapture());
    patterns.emplace_back(new ViewOpCapture());
    patterns.emplace_back(new MemcpyOpCapture());
    patterns.emplace_back(new ReinterpretCastOpCapture());
  }

  if (cuda_graph_level_ >= 2) {
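
The new pattern reuses the existing CloneOp template: memref.reinterpret_cast is a pure metadata operation (it builds a new memref descriptor without touching memory), so it is safe to clone it into the graph capture function instead of threading its result across the function boundary as a capture argument. The sketch below shows the general shape of such a clone-based capture pattern. The class names CloneOp and OpCapturePattern come from the diff, but their bodies here (the Matches method and so on) are hypothetical stand-ins, since the actual interface is not part of this hunk.

// Illustrative sketch only, not the XLA implementation.
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // mlir::memref::ReinterpretCastOp
#include "mlir/IR/Operation.h"

namespace sketch {

// Hypothetical base interface: decides whether an operation may be captured
// into the CUDA graph capture function.
struct OpCapturePattern {
  virtual ~OpCapturePattern() = default;
  virtual bool Matches(mlir::Operation* op) const = 0;
};

// Clone-based capture: ops of type OpT are side-effect free, so they can be
// cloned into the capture function rather than passed in as arguments.
template <typename OpT>
struct CloneOp : public OpCapturePattern {
  bool Matches(mlir::Operation* op) const override {
    return mlir::isa<OpT>(op);
  }
};

// Mirrors the pattern added in this commit.
struct ReinterpretCastOpCapture
    : public CloneOp<mlir::memref::ReinterpretCastOp> {};

}  // namespace sketch

Registering an instance in runOnOperation, as shown in the second hunk above, is what actually enables the new capture.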
@@ -535,3 +535,42 @@ module attributes {gpu.container_module} {
// CHECK: gpu.memcpy
// CHECK: gpu.memcpy
// CHECK-NEXT: return

// -----
// Check that memref.reinterpret_cast operations are cloned into the graph
// capture function.

module attributes {gpu.container_module} {

gpu.module @gpu_module attributes {binary = "kernel binary"} {
  gpu.func @fn0(%arg0: memref<16xi8, strided<[1], offset: 0>>) kernel { gpu.return }
  gpu.func @fn1(%arg0: memref<16xi8, strided<[1], offset: 0>>) kernel { gpu.return }
}

// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>)
func.func @func(%arg0: memref<16xi8>) {
  %c1 = arith.constant 1 : index
  %view = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1] : memref<16xi8> to memref<16xi8, strided<[1], offset: 0>>

  call @external() : () -> ()

  // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]])
  // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture}
  // CHECK-NEXT: return
  gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1)
                  threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>)
  gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1)
                  threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>)

  func.return
}

func.func private @external()
}

// CHECK: func @xla.gpu.cuda.graph.capture
// CHECK-NEXT: arith.constant 1
// CHECK-NEXT: memref.reinterpret_cast
// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0
// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1
// CHECK-NEXT: return
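
As the CHECK lines show, the capture function @xla.gpu.cuda.graph.capture receives only the original buffer argument; the reinterpret_cast (and the constant it depends on) is re-materialized inside the capture function rather than passed across the boundary. A rough sketch of that cloning step in terms of the standard MLIR builder API follows; the helper function and its setup are hypothetical and not taken from the pass, only the builder calls are real MLIR API.

// Illustrative sketch: cloning a pure op such as memref.reinterpret_cast into
// the body of an outlined capture function. Not the XLA pass itself.
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"

// `capture_body` is the entry block of the capture function; `mapping` maps
// values of the original function (e.g. its memref arguments) to the capture
// function's block arguments.
static void CloneIntoCapture(mlir::Operation* op, mlir::Block* capture_body,
                             mlir::IRMapping& mapping) {
  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(capture_body);
  // Operands are remapped through `mapping`, and the clone's results are
  // recorded in `mapping`, so later captured ops (here, the gpu.launch_func
  // ops that consume the cast) resolve to the cloned value.
  builder.clone(*op, mapping);
}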
