Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial integration of ZenDNN as backend into PyTorch #76242

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,9 @@
[submodule "third_party/nlohmann"]
path = third_party/nlohmann
url = https://github.com/nlohmann/json.git
# AMD BLIS BLAS library, added alongside ZenDNN in this change
# (presumably a ZenDNN build dependency — confirm against ZenDNN's docs).
[submodule "third_party/blis"]
path = third_party/blis
url = https://github.com/amd/blis
# AMD ZenDNN CPU library sources for the new backend.
[submodule "third_party/ZenDNN"]
path = third_party/ZenDNN
url = https://github.com/amd/ZenDNN
8 changes: 8 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ generated_cpu_cpp = [
# "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
"aten/src/ATen/RegisterMkldnnCPU.cpp",
"aten/src/ATen/RegisterNestedTensorCPU.cpp",
"aten/src/ATen/RegisterZendnnCPU.cpp",
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
Expand Down Expand Up @@ -183,6 +184,11 @@ filegroup(
srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]),
)

# Collects the ZenDNN-specific ATen native kernel sources, mirroring the
# aten_native_mkldnn_cpp filegroup above.
filegroup(
name = "aten_native_zendnn_cpp",
srcs = glob(["aten/src/ATen/native/zendnn/*.cpp"]),
)

filegroup(
name = "aten_native_xnnpack",
srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]),
Expand Down Expand Up @@ -252,6 +258,7 @@ header_template_rule(
include = "aten/src",
substitutions = {
"@AT_MKLDNN_ENABLED@": "1",
"@AT_ZENDNN_ENABLED@": "0",
"@AT_MKL_ENABLED@": "1",
"@AT_MKL_SEQUENTIAL@": "0",
"@AT_FFTW_ENABLED@": "0",
Expand Down Expand Up @@ -336,6 +343,7 @@ cc_library(
":aten_native_cpp",
":aten_native_mkl_cpp",
":aten_native_mkldnn_cpp",
":aten_native_zendnn_cpp",
":aten_native_quantized_cpp",
":aten_native_sparse_cpp",
":aten_native_nested_cpp",
Expand Down
26 changes: 22 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ if(APPLE)
endif()

# Host CPU family detection. This change renames CPU_INTEL to CPU_X64
# (AMD64/x86_64 is not Intel-specific); both the removed and added lines
# appear below because this is a flattened diff view.
# NOTE(review): a cmake_dependent_option later in this file still uses
# "CPU_INTEL" in its condition (see the hunk ending with `"CPU_INTEL" OFF)`);
# after the rename that variable is never set, so the condition is always
# false — confirm every CPU_INTEL reference was updated.
set(CPU_AARCH64 OFF)
set(CPU_INTEL OFF)
set(CPU_X64 OFF)

if(CMAKE_SYSTEM_PROCESSOR MATCHES "(AMD64|x86_64)")
set(CPU_INTEL ON)
set(CPU_X64 ON)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)")
set(CPU_AARCH64 ON)
endif()
Expand Down Expand Up @@ -301,9 +301,16 @@ cmake_dependent_option(
"CPU_INTEL" OFF)
# Ensure that an MKLDNN build is the default for x86 CPUs
# but optional for AArch64 (dependent on -DUSE_MKLDNN).

# ZenDNN and MKLDNN provide overlapping CPU backends; building with both
# enabled is unsupported, so fail fast with a clear diagnostic.
# (message(FATAL_ERROR) already stops processing, so no return() is needed;
# commands are lowercase and endif() takes no argument per modern CMake style.)
if(USE_ZENDNN AND USE_MKLDNN)
  message(FATAL_ERROR "Enable only either ZenDNN or MKLDNN as a Backend")
endif()

cmake_dependent_option(
USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, and AArch64." "${CPU_INTEL}"
"CPU_INTEL OR CPU_AARCH64" OFF)
USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, and AArch64." "${CPU_X64}"
"CPU_X64 OR CPU_AARCH64;NOT USE_ZENDNN" OFF)

cmake_dependent_option(
USE_MKLDNN_ACL "Use Compute Library for the Arm architecture." OFF
"USE_MKLDNN AND CPU_AARCH64" OFF)
Expand All @@ -312,6 +319,16 @@ cmake_dependent_option(
USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF
"USE_MKLDNN" OFF)
option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF)


# ZenDNN is opt-in: the dependent condition references USE_ZENDNN itself, so
# the option can only end up ON when the user explicitly passes
# -DUSE_ZENDNN=ON on an x86-64 build; otherwise cmake_dependent_option forces
# it OFF. NOTE(review): confirm this self-reference is intentional — unlike
# USE_MKLDNN above, the condition has no "NOT USE_MKLDNN" term, so mutual
# exclusion relies solely on the FATAL_ERROR check earlier in this file.
cmake_dependent_option(
USE_ZENDNN "Use ZENDNN. Only available on x86 and x86_64." "${CPU_X64}"
"CPU_X64;USE_ZENDNN" OFF)
# Propagate the user's choice into ZenDNN's own concurrent-execution switch.
set(ZENDNN_ENABLE_CONCURRENT_EXEC ${USE_ZENDNN})
cmake_dependent_option(
USE_ZENDNN_CBLAS "Use CBLAS in ZENDNN" OFF
"USE_ZENDNN" OFF)

option(USE_DISTRIBUTED "Use distributed" ON)
cmake_dependent_option(
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
Expand Down Expand Up @@ -1099,6 +1116,7 @@ if(BUILD_SHARED_LIBS)
${PROJECT_SOURCE_DIR}/cmake/public/gflags.cmake
${PROJECT_SOURCE_DIR}/cmake/public/mkl.cmake
${PROJECT_SOURCE_DIR}/cmake/public/mkldnn.cmake
${PROJECT_SOURCE_DIR}/cmake/public/zendnn.cmake
${PROJECT_SOURCE_DIR}/cmake/public/protobuf.cmake
${PROJECT_SOURCE_DIR}/cmake/public/threads.cmake
${PROJECT_SOURCE_DIR}/cmake/public/utils.cmake
Expand Down
7 changes: 6 additions & 1 deletion aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ file(GLOB mkldnn_cpp "mkldnn/*.cpp")
file(GLOB native_cpp "native/*.cpp")
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
file(GLOB native_zendnn_cpp "native/zendnn/*.cpp")
file(GLOB vulkan_cpp "vulkan/*.cpp")
file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/ops/*.cpp")

Expand Down Expand Up @@ -176,7 +177,7 @@ else()
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
${native_transformers_cpp}
${native_zendnn_cpp} ${native_transformers_cpp}
${native_utils_cpp} ${native_xnnpack} ${generated_sources} ${core_generated_sources}
${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}
)
Expand Down Expand Up @@ -346,6 +347,10 @@ if(MKLDNN_FOUND)
list(APPEND ATen_CPU_DEPENDENCY_LIBS ${MKLDNN_LIBRARIES})
endif(MKLDNN_FOUND)

# Link ZenDNN into the CPU dependency list when the library was found,
# mirroring the MKLDNN handling above. Modern CMake style: bare endif()
# instead of the legacy repeated-argument endif(ZENDNN_FOUND) form.
if(ZENDNN_FOUND)
  list(APPEND ATen_CPU_DEPENDENCY_LIBS ${ZENDNN_LIBRARIES})
endif()

list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)

if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/Config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h

#define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@
#define AT_ZENDNN_ENABLED() @AT_ZENDNN_ENABLED@
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
#define AT_MKL_SEQUENTIAL() @AT_MKL_SEQUENTIAL@
#define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@
Expand Down
17 changes: 17 additions & 0 deletions aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,19 @@ bool Context::userEnabledMkldnn() const {
return enabled_mkldnn;
}

// Returns whether the user has the ZenDNN backend enabled at runtime
// (defaults to true; see enabled_zendnn in Context.h). Independent of
// whether ZenDNN support was compiled in.
bool Context::userEnabledZendnn() const {
return enabled_zendnn;
}

// Enables/disables runtime use of the MKLDNN (oneDNN) backend.
void Context::setUserEnabledMkldnn(bool e) {
enabled_mkldnn = e;
}

// Enables/disables runtime use of the ZenDNN backend.
void Context::setUserEnabledZendnn(bool e) {
enabled_zendnn = e;
}


bool Context::deterministicCuDNN() const {
return deterministic_cudnn;
}
Expand Down Expand Up @@ -236,6 +245,14 @@ bool Context::hasMKLDNN() {
#endif
}

// Reports whether this build of ATen was compiled with ZenDNN support.
// Purely a compile-time check via AT_ZENDNN_ENABLED (set from the
// @AT_ZENDNN_ENABLED@ substitution in Config.h.in); no runtime probing.
bool Context::hasZENDNN() {
#if AT_ZENDNN_ENABLED()
return true;
#else
return false;
#endif
}

bool Context::hasMPS() {
#if USE_MPS
return at::mps::is_available();
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class TORCH_API Context {
static bool hasMKL();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool hasZENDNN();
static bool hasMAGMA() {
return detail::getCUDAHooks().hasMAGMA();
}
Expand Down Expand Up @@ -119,6 +120,8 @@ class TORCH_API Context {
void setUserEnabledCuDNN(bool e);
bool userEnabledMkldnn() const;
void setUserEnabledMkldnn(bool e);
bool userEnabledZendnn() const;
void setUserEnabledZendnn(bool e);
bool benchmarkCuDNN() const;
void setBenchmarkCuDNN(bool);
int benchmarkLimitCuDNN() const;
Expand Down Expand Up @@ -260,6 +263,7 @@ class TORCH_API Context {
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool enabled_mkldnn = true;
bool enabled_zendnn = true;
at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
#ifdef C10_MOBILE
bool release_original_weights = true;
Expand Down Expand Up @@ -371,6 +375,10 @@ static inline bool hasMKLDNN() {
return globalContext().hasMKLDNN();
}

// Convenience free function mirroring Context::hasZENDNN(): true iff this
// build was compiled with ZenDNN support.
static inline bool hasZENDNN() {
return globalContext().hasZENDNN();
}

static inline void manual_seed(uint64_t seed) {
auto gen = globalContext().defaultGenerator(DeviceType::CPU);
{
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/core/Formatting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
} else if (tensor_.is_mkldnn()) {
stream << "MKLDNN Tensor: ";
tensor = tensor_.to_dense().to(kCPU, kDouble).contiguous();
} else {
} else if (tensor_.is_zendnn()) {
stream << "ZENDNN Tensor: ";
tensor = tensor_.to_dense().to(kCPU, kDouble).contiguous();
}else {
tensor = tensor_.to(kCPU, kDouble).contiguous();
}
if(tensor.ndimension() == 0) {
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/core/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,12 @@ class TORCH_API TensorBase {
return impl_->is_mkldnn();
}

/// Returns true if the `Tensor` is a ZenDNN tensor, as reported by the
/// underlying TensorImpl.
bool is_zendnn() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_zendnn();
}

/// Returns if a `Tensor` is mps tensor.
bool is_mps() const {
// NB: this is not a native function to avoid dispatching overhead.
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace c10 {
_(prim, MKLDNNHardSigmoid) \
_(prim, MKLDNNHardTanh) \
_(prim, MKLDNNClamp) \
_(prim, zendnn_convolution) \
_(prim, StaticRuntimeCopyOuts) \
_(prim, Drop) \
_(prim, Eval) \
Expand Down
7 changes: 7 additions & 0 deletions aten/src/ATen/native/ConvUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tens
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
// Function-pointer type for the ZenDNN convolution backward kernel; the
// signature is identical to mkldnn_convolution_backward_fn declared above,
// and the stub is registered via the same DECLARE_DISPATCH mechanism.
using zendnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(zendnn_convolution_backward_fn, zendnn_convolution_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
Expand Down Expand Up @@ -107,6 +111,7 @@ struct ConvParams {
bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const;
bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const;
bool use_zendnn(const at::Tensor& input, const at::Tensor& weight) const;
bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const;
bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
const at::OptionalIntArrayRef bias_sizes_opt) const;
Expand All @@ -125,6 +130,8 @@ enum class ConvBackend {
MiopenTranspose,
Mkldnn,
MkldnnEmpty,
Zendnn,
ZendnnEmpty,
NnpackSpatial,
Overrideable,
Slow2d,
Expand Down