-
Notifications
You must be signed in to change notification settings - Fork 74k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #8874: [GPU] Use NCCL user buffers for collective permute and all-…
…to-all Imported from GitHub PR openxla/xla#8874 This PR enables XLA to take advantage of NCCL user buffers for ncclSend/ncclRecv when `--xla_gpu_enable_nccl_user_buffers=true` is used. Requires NCCL 2.20 Copybara import of the project: -- 98acdf27d4eba6b19652a76d3f7dcd6630349fc5 by Trevor Morris <tmorris@nvidia.com>: Use NCCL user buffers for ncclSend/ncclRecv ops -- bcc289b49bcf2086b50a86a2381ea1b80acd3dd2 by Trevor Morris <tmorris@nvidia.com>: Include memory space in buffers for collective permute and send/recv -- 4a83d8906b6b5e305dad23fc1d8b9a5069637279 by Trevor Morris <tmorris@nvidia.com>: Don't offload send, recv -- 0083a418c4ab119ed5a0eb061113104980476943 by Trevor Morris <tmorris@nvidia.com>: Fix conditional Merging this change closes #8874 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#8874 from trevor-m:p2p-user-buffers 0083a418c4ab119ed5a0eb061113104980476943 PiperOrigin-RevId: 615104094
- Loading branch information
1 parent
b960be6
commit 6d82b56
Showing
13 changed files
with
179 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/cast_bf16.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s | ||
// Ensure cast with bfloat16 roundtrip exactly | ||
|
||
func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> { | ||
^bb0(%arg0: tensor<4x5xbf16>): | ||
// CHECK-LABEL: @main | ||
// CHECK: (tensor<4x5xbf16>) -> tensor<4x5xf32> | ||
// CHECK-NEXT: (tensor<4x5xf32>) -> tensor<4x5xbf16> | ||
%0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1") | ||
%1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2") | ||
func.return %1 : tensor<4x5xbf16> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/cast_bf16.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck %s | ||
|
||
func.func @main(tensor<4x5xbf16>) -> tensor<4x5xbf16> { | ||
^bb0(%arg0: tensor<4x5xbf16>): | ||
|
||
// CHECK: { | ||
// CHECK-NEXT: version: 3, | ||
// CHECK-NEXT: operator_codes: [ { | ||
// CHECK-NEXT: deprecated_builtin_code: 53, | ||
// CHECK-NEXT: version: 7, | ||
// CHECK-NEXT: builtin_code: CAST | ||
// CHECK-NEXT: } ], | ||
// CHECK-NEXT: subgraphs: [ { | ||
// CHECK-NEXT: tensors: [ { | ||
// CHECK-NEXT: shape: [ 4, 5 ], | ||
// CHECK-NEXT: type: BFLOAT16, | ||
// CHECK-NEXT: buffer: 1, | ||
// CHECK-NEXT: name: "arg0", | ||
// CHECK-NEXT: quantization: { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, | ||
// CHECK-NEXT: has_rank: true | ||
// CHECK-NEXT: }, { | ||
// CHECK-NEXT: shape: [ 4, 5 ], | ||
// CHECK-NEXT: buffer: 2, | ||
// CHECK-NEXT: name: "cast1", | ||
// CHECK-NEXT: quantization: { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, | ||
// CHECK-NEXT: has_rank: true | ||
// CHECK-NEXT: }, { | ||
// CHECK-NEXT: shape: [ 4, 5 ], | ||
// CHECK-NEXT: type: BFLOAT16, | ||
// CHECK-NEXT: buffer: 3, | ||
// CHECK-NEXT: name: "cast2", | ||
// CHECK-NEXT: quantization: { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, | ||
// CHECK-NEXT: has_rank: true | ||
// CHECK-NEXT: } ], | ||
// CHECK-NEXT: inputs: [ 0 ], | ||
// CHECK-NEXT: outputs: [ 2 ], | ||
// CHECK-NEXT: operators: [ { | ||
// CHECK-NEXT: inputs: [ 0 ], | ||
// CHECK-NEXT: outputs: [ 1 ] | ||
// CHECK-NEXT: }, { | ||
// CHECK-NEXT: inputs: [ 1 ], | ||
// CHECK-NEXT: outputs: [ 2 ] | ||
// CHECK-NEXT: } ], | ||
// CHECK-NEXT: name: "main" | ||
// CHECK-NEXT: } ], | ||
// CHECK-NEXT: description: "MLIR Converted.", | ||
// CHECK-NEXT: buffers: [ { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, { | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }, { | ||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] | ||
// CHECK-NEXT: } ], | ||
// CHECK-NEXT: metadata: [ { | ||
// CHECK-NEXT: name: "min_runtime_version", | ||
// CHECK-NEXT: buffer: 4 | ||
// CHECK-NEXT: } ], | ||
// CHECK-NEXT: signature_defs: [ ] | ||
// CHECK-NEXT: } | ||
|
||
%0 = "tfl.cast" (%arg0) : (tensor<4x5xbf16>) -> tensor<4x5xf32> loc("cast1") | ||
%1 = "tfl.cast" (%0) : (tensor<4x5xf32>) -> tensor<4x5xbf16> loc("cast2") | ||
func.return %1 : tensor<4x5xbf16> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters