Skip to content

Commit

Permalink
Minor refactor: rename the 'lower bound batch threads' transform to a more generic 'reconfig batch op'. It makes no logical changes.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 636898956
  • Loading branch information
tensorflower-gardener committed May 28, 2024
1 parent ae7327d commit 471e2a6
Show file tree
Hide file tree
Showing 33 changed files with 378 additions and 922 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/tfrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ cc_library(
"transforms/deduplicate_if_result_pass.cc",
"transforms/fuse_tpu_compile_and_execute_ops.cc",
"transforms/insert_tensor_copy.cc",
"transforms/lower_bound_batch_threads.cc",
"transforms/lower_saved_model.cc",
"transforms/merge_tf_if_ops.cc",
"transforms/optimize.cc",
"transforms/optimize_tf_control_flow_side_effect.cc",
"transforms/passes.cc",
"transforms/reconfig_batch_op.cc",
"transforms/remove_device_attribute.cc",
"transforms/remove_tf_if_const_args.cc",
"transforms/reorder_assert.cc",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tf-tfrt-opt -split-input-file -tfrt-lower-bound-batch-threads="tfrt-min-num-batch-threads=2" %s | FileCheck %s --dump-input=always
// RUN: tf-tfrt-opt -split-input-file -tfrt-reconfig-batch-op="tfrt-min-num-batch-threads=2" %s | FileCheck %s --dump-input=always

// -----

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/tfrt/transforms/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper(
pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());

// Lower bound on the number of batch threads in `tf.BatchFunction`.
pm.addPass(tfrt_compiler::CreateLowerBoundBatchThreadsPass(
options.min_num_batch_threads));
pm.addPass(tfrt_compiler::CreateReconfigBatchOpPass(
{.min_num_batch_threads = options.min_num_batch_threads}));

// Deduplicate functions invoked by tf.BatchFunction with the same
// shared_name
Expand Down
8 changes: 6 additions & 2 deletions tensorflow/compiler/mlir/tfrt/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_

#include <cstdint>
#include <memory>

#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -67,8 +68,11 @@ std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass();

// Creates a pass that reconfigures `tf.BatchFunction` ops; currently it
// enforces a lower bound on the ops' number of batch threads (see
// `min_num_batch_threads` below).
struct ReconfigBatchOpPassOptions {
  // Lower bound applied to the number of batch threads of each
  // `tf.BatchFunction` op. The default of 1 is the weakest possible bound.
  int64_t min_num_batch_threads = 1;
};
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateReconfigBatchOpPass(
    ReconfigBatchOpPassOptions options);

// Create a pass to fuse the TPU Ops for TFRT.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,31 @@ namespace tensorflow {
namespace tfrt_compiler {
namespace {

class LowerBoundBatchThreadsPass
: public mlir::PassWrapper<LowerBoundBatchThreadsPass,
class ReconfigBatchOpPass
: public mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>> {
public:
explicit LowerBoundBatchThreadsPass(uint64_t min_num_batch_threads)
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
explicit ReconfigBatchOpPass(ReconfigBatchOpPassOptions options)
: mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>>() {
min_num_batch_threads_ = min_num_batch_threads;
min_num_batch_threads_ = options.min_num_batch_threads;
}
LowerBoundBatchThreadsPass()
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
ReconfigBatchOpPass()
: mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>>() {}
LowerBoundBatchThreadsPass(const LowerBoundBatchThreadsPass& other)
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
ReconfigBatchOpPass(const ReconfigBatchOpPass& other)
: mlir::PassWrapper<ReconfigBatchOpPass,
mlir::OperationPass<mlir::ModuleOp>>(other) {}

LowerBoundBatchThreadsPass& operator=(
const LowerBoundBatchThreadsPass& other) = delete;
ReconfigBatchOpPass& operator=(const ReconfigBatchOpPass& other) = delete;

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBoundBatchThreadsPass)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReconfigBatchOpPass)

private:
llvm::StringRef getArgument() const final {
return "tfrt-lower-bound-batch-threads";
}
llvm::StringRef getArgument() const final { return "tfrt-reconfig-batch-op"; }

llvm::StringRef getDescription() const final {
return "Lower bound batch threads for batch ops.";
return "Reconfig batch op such as num_batch_threads.";
}

void runOnOperation() override {
Expand All @@ -82,12 +79,12 @@ class LowerBoundBatchThreadsPass

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads) {
return std::make_unique<LowerBoundBatchThreadsPass>(min_num_batch_threads);
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateReconfigBatchOpPass(
ReconfigBatchOpPassOptions options) {
return std::make_unique<ReconfigBatchOpPass>(options);
}

static mlir::PassRegistration<LowerBoundBatchThreadsPass> register_pass;
static mlir::PassRegistration<ReconfigBatchOpPass> register_pass;

} // namespace tfrt_compiler
} // namespace tensorflow
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Dockerfile to build the RBE manylinux2014 multipython image for
# CUDA 12.1 + cuDNN 9.1 on Ubuntu 20.04.
#
# It extends a pinned tensorflow-sigs/build base image with cuDNN 9.1 and
# several CPython versions built from source. (manylinux2014 targets
# glibc 2.17 per PEP 599; the "manylinux 2010 / glibc 2.12 / libstdc++ 4.4"
# wording that previously appeared in this header was stale boilerplate.)
#
# To push a new version, run:
# $ docker build -f Dockerfile.rbe.cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython \
#   --tag "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython" .
# $ docker push gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython

# Base image is pinned by digest so rebuilds are reproducible.
FROM gcr.io/tensorflow-sigs/build@sha256:3573cdabdea7f203b6440a93bb50a0e1a17c2c9a33f09fccdc0c97f514f9689c

# Suppress interactive prompts from apt during the build.
ENV DEBIAN_FRONTEND=noninteractive

COPY install/install_bootstrap_deb_packages.sh /install/
RUN /install/install_bootstrap_deb_packages.sh

COPY install/install_deb_packages.sh /install/
RUN /install/install_deb_packages.sh

# CPython build dependencies, plus cuDNN 9.1 pinned to an exact package
# version so the image contents are deterministic. The apt lists are removed
# afterwards to keep the image small.
RUN apt-get update && apt-get install -y \
    libbz2-dev \
    libffi-dev \
    libgdbm-dev \
    libncurses5-dev \
    libnss3-dev \
    libreadline-dev \
    libsqlite3-dev \
    patchelf \
    libcudnn9-dev-cuda-12=9.1.1.17-1 \
    libcudnn9-cuda-12=9.1.1.17-1 \
    && \
    rm -rf /var/lib/apt/lists/*

# Build multiple CPython versions from source (the "multipython" in the tag).
COPY install/build_and_install_python.sh /install/
RUN /install/build_and_install_python.sh "3.9.18"
RUN /install/build_and_install_python.sh "3.10.13"
RUN /install/build_and_install_python.sh "3.11.6"
RUN /install/build_and_install_python.sh "3.12.2"

COPY install/install_pip_packages_by_version.sh /install/
# https://github.com/numpy/numpy/issues/22623 for `SETUPTOOLS_USE_DISTUTILS`.
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" "jax"
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" "jax"
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.11" "jax"
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.12" "jax"
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Dockerfile to build a manylinux 2010 compliant cross-compiler.
#
# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible
# glibc (2.12) and system libstdc++ (4.4).
#
# To push a new version, run:
# $ docker build -f Dockerfile.rbe.cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython \
# --tag "gcr.io/tensorflow-testing/nosla-cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython" .
# $ docker push gcr.io/tensorflow-testing/nosla-cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython

FROM gcr.io/tensorflow-sigs/build@sha256:7c8ecb6482e26c4b4efce0ddaefe3fb3667b3b958c83fe8d3cc3763c6ed7a4d1

ENV DEBIAN_FRONTEND=noninteractive

COPY install/install_bootstrap_deb_packages.sh /install/
RUN /install/install_bootstrap_deb_packages.sh

COPY install/install_deb_packages.sh /install/
RUN /install/install_deb_packages.sh

RUN apt-get update && apt-get install -y \
libbz2-dev \
libffi-dev \
libgdbm-dev \
libncurses5-dev \
libnss3-dev \
libreadline-dev \
libsqlite3-dev \
patchelf \
libcudnn9-dev-cuda-12=9.1.1.17-1 \
libcudnn9-cuda-12=9.1.1.17-1 \
&& \
rm -rf /var/lib/apt/lists/*

COPY install/build_and_install_python.sh /install/
RUN /install/build_and_install_python.sh "3.9.18"
RUN /install/build_and_install_python.sh "3.10.13"
RUN /install/build_and_install_python.sh "3.11.6"
RUN /install/build_and_install_python.sh "3.12.0"

COPY install/install_pip_packages_by_version.sh /install/
# https://github.com/numpy/numpy/issues/22623 for `SETUPTOOLS_USE_DISTUTILS`.
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" "jax"
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" "jax"
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.11" "jax"
RUN SETUPTOOLS_USE_DISTUTILS=stdlib /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.12" "jax"
16 changes: 16 additions & 0 deletions tensorflow/tools/toolchains/remote_config/containers.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ container_digests = {
# JAX manylinux2014 configs.
"cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:45619e91f14faabddd79fe0cb1526df4c4ad92fc2e6ebdc725ea4419225429c3",
"cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:8c266e5b0acd203aed5e8871b63f68a39d8d23f6d882e619797e58b973f7fe63",
"cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:7822b47867ecfc1f57df1cfadeaf091b72191d94cb722c271ed38809be7e7a61",
"cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428",
"cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:0c78f3428cde36f041b758fc2f01d23d2f0dd72dec248f78667fb0c9d1f74cef",
"cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:6f9524a2ed7f75255dc4be3a0c5e3bda581385a1c13e2fa890bc17fa62da95b2",
"cuda12.3-cudnn8.9-ubuntu22.04-manylinux2014-multipython": "sha256:dddcaf30321e9007103dce75c51b83fea3c06de462fcf41e7c6ae93f37fc3545",
"cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:7128b33e8f22d5f5ec9640bc377c3afddf7eb31daa1f958d1dd91dd7fda6a790",
Expand Down Expand Up @@ -102,13 +104,27 @@ containers = {
"digest": container_digests["cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython"],
},

# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython.
"cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": {
"registry": "gcr.io",
"repository": "tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython",
"digest": container_digests["cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython"],
},

# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython.
"cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": {
"registry": "gcr.io",
"repository": "tensorflow-testing/nosla-cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython",
"digest": container_digests["cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython"],
},

# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython.
"cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython": {
"registry": "gcr.io",
"repository": "tensorflow-testing/nosla-cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython",
"digest": container_digests["cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython"],
},

# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython.
"cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": {
"registry": "gcr.io",
Expand Down
67 changes: 0 additions & 67 deletions third_party/triton/temporary/enable_mma_v3.patch

This file was deleted.

14 changes: 0 additions & 14 deletions third_party/triton/temporary/exclude_failing_h100_tests.patch

This file was deleted.

35 changes: 35 additions & 0 deletions third_party/triton/temporary/fp8_splat_partial_revert.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
This can be deleted on the next integrate; it is a revert of a previous patch.
diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -26,6 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { re
inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }

+inline bool isFloat(Type type) {
+ return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
+ type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
+ type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
+ type.isFloat8E5M2FNUZ();
+}
+
+inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
+
} // namespace type
} // namespace triton
} // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -74,9 +74,9 @@ struct ArithConstantSplatOpConversion
auto values = mlir::dyn_cast<SplatElementsAttr>(op.getValue());
auto elemType = values.getElementType();
Attribute val;
- if (isa<FloatType>(elemType)) {
+ if (type::isFloat(elemType)) {
val = values.getValues<FloatAttr>()[0];
- } else if (isa<IntegerType>(elemType)) {
+ } else if (type::isInt(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
Loading

0 comments on commit 471e2a6

Please sign in to comment.