diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh
index c36761b8ce58..4c658462aa0d 100644
--- a/.circleci/scripts/binary_ios_upload.sh
+++ b/.circleci/scripts/binary_ios_upload.sh
@@ -14,7 +14,7 @@ mkdir -p ${ZIP_DIR}/src
 cp -R ${ARTIFACTS_DIR}/arm64/include ${ZIP_DIR}/install/
 # build a FAT binary
 cd ${ZIP_DIR}/install/lib
-target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a)
+target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a libXNNPACK.a)
 for lib in ${target_libs[*]}
 do
     if [ -f "${ARTIFACTS_DIR}/x86_64/lib/${lib}" ] && [ -f "${ARTIFACTS_DIR}/arm64/lib/${lib}" ]; then
diff --git a/.gitmodules b/.gitmodules
index 228e123cc00b..599b1df431f2 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -111,10 +111,14 @@
 	path = third_party/foxi
 	url = https://github.com/houseroad/foxi.git
 [submodule "third_party/tbb"]
-    path = third_party/tbb
-    url = https://github.com/01org/tbb
-    branch = tbb_2018
+	path = third_party/tbb
+	url = https://github.com/01org/tbb
+	branch = tbb_2018
 [submodule "android/libs/fbjni"]
 	ignore = dirty
 	path = android/libs/fbjni
 	url = https://github.com/facebookincubator/fbjni.git
+[submodule "third_party/XNNPACK"]
+	path = third_party/XNNPACK
+	url = https://github.com/AshkanAliabadi/XNNPACK.git
+	branch = xnnpack_pytorch_merge_temp
diff --git a/CMakeLists.txt b/CMakeLists.txt
index efae635669f1..67849fb51210 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -185,6 +185,7 @@
 option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
 option(USE_SYSTEM_EIGEN_INSTALL "Use system Eigen instead of the one under third_party" OFF)
 option(USE_TENSORRT "Using Nvidia TensorRT library" OFF)
+option(USE_XNNPACK "Use XNNPACK" ON)
 option(USE_ZMQ "Use ZMQ" OFF)
 option(USE_ZSTD "Use ZSTD" OFF)
 cmake_dependent_option(
@@ -416,6 +417,10 @@
 if(USE_PYTORCH_QNNPACK)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PYTORCH_QNNPACK")
 endif()
 
+if(USE_XNNPACK)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK")
+endif()
+
 # ---[ Whitelist file if whitelist is specified
 include(cmake/Whitelist.cmake)
diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt
index a00a3845c432..b83413df0a07 100644
--- a/android/pytorch_android/CMakeLists.txt
+++ b/android/pytorch_android/CMakeLists.txt
@@ -83,6 +83,7 @@ if (ANDROID_ABI)
   import_static_lib(libtorch_cpu)
   import_static_lib(libc10)
   import_static_lib(libnnpack)
+  import_static_lib(libXNNPACK)
   import_static_lib(libpytorch_qnnpack)
   import_static_lib(libeigen_blas)
   import_static_lib(libcpuinfo)
@@ -98,6 +99,7 @@
     -Wl,--no-whole-archive
     libc10
     libnnpack
+    libXNNPACK
     libpytorch_qnnpack
     libeigen_blas
     libcpuinfo
@@ -113,6 +115,7 @@
     torch_cpu
     c10
     nnpack
+    XNNPACK
     pytorch_qnnpack
     cpuinfo
     clog
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index d6ee1796ce02..3d45042bffb2 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -84,8 +84,11 @@
 FILE(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
 FILE(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
 FILE(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
+# XNNPACK
+FILE(GLOB native_xnnpack "native/xnnpack/*.cpp")
+
 add_subdirectory(quantized)
-set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
+set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${native_xnnpack} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
 if(AT_MKL_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
 endif()
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b0cc2bcaada1..8c648410087b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -775,6 +775,10 @@
 - func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor
 
+- func: _conv2d_prepack(Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1, float? output_min=None, float? output_max=None) -> Tensor
+
+- func: _conv2d_packed(Tensor packed_weight, Tensor input) -> Tensor
+
 - func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor
 
 - func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
@@ -1575,6 +1579,10 @@
 - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
   python_module: nn
 
+- func: _linear_prepack(Tensor weight, Tensor? bias=None, float? output_min=None, float? output_max=None) -> Tensor
+
+- func: _linear_packed(Tensor packed_weight, Tensor input) -> Tensor
+
 - func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
   python_module: nn
   dispatch:
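
The two-step contract behind the new prepack/packed entries is worth seeing end to end. A minimal sketch, assuming the usual ATen codegen exposes these yaml entries as `at::_conv2d_prepack` and `at::_conv2d_packed` (function name and shapes here are illustrative, not part of the diff):

```cpp
#include <ATen/ATen.h>

// Sketch: pack once, run many times. Assumes codegen exposure of the new
// yaml entries as at::_conv2d_prepack / at::_conv2d_packed.
at::Tensor conv_with_prepacked_weight() {
  const at::Tensor input = at::rand({1, 3, 224, 224});  // NCHW, float
  const at::Tensor weight = at::rand({16, 3, 3, 3});    // OIHW, float
  const at::Tensor bias = at::rand({16});

  // Done once, e.g. at model load: reorders the weight into the
  // XNNPACK-friendly layout and caches the created xnn operator inside
  // the opaque "packed" tensor.
  const at::Tensor packed = at::_conv2d_prepack(
      weight, bias,
      /*stride=*/{1, 1}, /*padding=*/{0, 0}, /*dilation=*/{1, 1}, /*groups=*/1);

  // Done per inference: only per-call setup and kernel execution remain.
  return at::_conv2d_packed(packed, input);
}
```
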
diff --git a/aten/src/ATen/native/utils/Allocator.h b/aten/src/ATen/native/utils/Allocator.h
new file mode 100644
index 000000000000..af1a2df33959
--- /dev/null
+++ b/aten/src/ATen/native/utils/Allocator.h
@@ -0,0 +1,55 @@
+#pragma once
+
+#include <c10/core/CPUAllocator.h>
+
+namespace at {
+namespace native {
+
+// QNNPACK and XNNPACK may access the input and / or output tensors out of
+// bounds. This behavior will trigger ASAN, and may result in a segfault if
+// the accessed memory happens to fall on a page the current process has no
+// read access to. Here we define a custom allocator that allocates the extra
+// storage required to keep this behavior safe.
+//
+// PreGuardBytes: Number of guard bytes to allocate before the allocation.
+// PostGuardBytes: Number of guard bytes to allocate after the allocation.
+
+template <uint32_t PreGuardBytes, uint32_t PostGuardBytes>
+class GuardingAllocator final : public at::Allocator {
+ public:
+  GuardingAllocator() = default;
+  virtual ~GuardingAllocator() override = default;
+
+  static void deleter(void* pointer) {
+    const Cast memory{pointer};
+    c10::free_cpu(memory.as_byte_ptr - kPreGuardBytes);
+  }
+
+  virtual DataPtr allocate(size_t nbytes) const override {
+    Cast memory{c10::alloc_cpu(kPreGuardBytes + nbytes + kPostGuardBytes)};
+    memory.as_byte_ptr += kPreGuardBytes;
+
+    return {
+      memory.as_void_ptr,
+      memory.as_void_ptr,
+      &deleter,
+      at::Device(DeviceType::CPU),
+    };
+  }
+
+  virtual DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+
+ private:
+  static constexpr uint32_t kPreGuardBytes = PreGuardBytes;
+  static constexpr uint32_t kPostGuardBytes = PostGuardBytes;
+
+  union Cast final {
+    void * const as_void_ptr;
+    uint8_t * as_byte_ptr;
+  };
+};
+
+} // namespace native
+} // namespace at
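
The guard-byte layout is simple enough to restate standalone. A self-contained sketch with illustrative numbers (the instantiations in this patch are `<8u, 0u>` for the quantized path and `<0u, XNN_EXTRA_BYTES>` for the XNNPACK factory):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>

// Standalone sketch of the layout GuardingAllocator<Pre, Post> implements.
// One block is laid out as [pre guard | nbytes | post guard]; the caller
// receives a pointer just past the pre-guard, so a kernel's small
// out-of-bounds accesses land inside memory this process owns.
void* guarded_alloc(std::size_t nbytes, std::uint32_t pre, std::uint32_t post) {
  auto* base = static_cast<std::uint8_t*>(std::malloc(pre + nbytes + post));
  return base ? base + pre : nullptr;
}

void guarded_free(void* pointer, std::uint32_t pre) {
  // Undo the offset before handing the block back to the system allocator,
  // mirroring GuardingAllocator::deleter.
  if (pointer) {
    std::free(static_cast<std::uint8_t*>(pointer) - pre);
  }
}
```
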
diff --git a/aten/src/ATen/native/xnnpack/Common.h b/aten/src/ATen/native/xnnpack/Common.h
new file mode 100644
index 000000000000..3d318e83a46a
--- /dev/null
+++ b/aten/src/ATen/native/xnnpack/Common.h
@@ -0,0 +1,82 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+#ifdef USE_XNNPACK
+
+#include <xnnpack.h>
+
+namespace at {
+namespace native {
+namespace xnnpack {
+namespace internal {
+
+struct Layout final {
+  // 4D Activation Maps
+  struct Activation4D final {
+    static constexpr size_t batch = 0u;
+    static constexpr size_t channels = 1u;
+    static constexpr size_t height = 2u;
+    static constexpr size_t width = 3u;
+  };
+
+  // ND Activation Maps
+  struct ActivationND final {
+    // Some operators may not be limited to 4 dimensional tensors. In that scenario,
+    // XNNPACK denotes that operator with an _nc suffix and expects all dimensions,
+    // except channels, to be flattened into one argument: batch_size.
+    static int64_t batch(const IntArrayRef tensor) {
+      if (C10_UNLIKELY(tensor.empty())) {
+        return -1;
+      }
+
+      // Handle the case where batch size is zero.
+      int64_t batch = std::max<int64_t>(1, tensor[0]);
+
+      for (size_t index = 1u; index < (tensor.size() - 1u); ++index) {
+        batch *= tensor[index];
+      }
+
+      return batch;
+    };
+
+    static int64_t channel(const IntArrayRef tensor) {
+      if (C10_UNLIKELY(tensor.empty())) {
+        return -1;
+      }
+
+      return tensor.back();
+    };
+  };
+
+  // Convolution Filters
+  struct Filter final {
+    static constexpr size_t output = 0u;
+    static constexpr size_t input = 1u;
+    static constexpr size_t height = 2u;
+    static constexpr size_t width = 3u;
+  };
+
+  // Parameters (Pooling Kernels, Dilation, Padding, Stride, etc.)
+  struct Parameter final {
+    static constexpr size_t height = 0u;
+    static constexpr size_t width = 1u;
+  };
+};
+
+struct Deleter final {
+  void operator()(const xnn_operator_t op) const {
+    xnn_delete_operator(op);
+  }
+};
+
+using Operator = std::unique_ptr<xnn_operator, Deleter>;
+
+bool available();
+
+} // namespace internal
+} // namespace xnnpack
+} // namespace native
+} // namespace at
+
+#endif /* USE_XNNPACK */
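
The `_nc` flattening rule is easiest to verify on a concrete shape. A small sketch using the Layout helpers above, with the values worked out by hand:

```cpp
#include <cstdint>
#include <vector>

#include <ATen/native/xnnpack/Common.h>

// Worked example of the *_nc flattening convention: for a [2, 3, 4, 5]
// tensor, everything but the trailing channels dimension folds into
// batch_size, so batch = max(1, 2) * 3 * 4 = 24 and channel = 5.
void activation_nd_example() {
  using Layout = at::native::xnnpack::internal::Layout;

  const std::vector<int64_t> sizes = {2, 3, 4, 5};
  const int64_t batch = Layout::ActivationND::batch(sizes);      // 24
  const int64_t channels = Layout::ActivationND::channel(sizes); // 5
  (void)batch;
  (void)channels;
}
```
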
" + "Reason: The provided (weight, bias, padding, stride, dilation, groups, output_min, output_max) " + "parameters are either invalid individually or their combination is not supported by XNNPACK."); + + xnn_operator_t convolution_op{}; + + const xnn_status create_status = xnn_create_convolution2d_nhwc_f32( + padding[Layout::Parameter::height], // input_padding_top + padding[Layout::Parameter::width], // input_padding_right + padding[Layout::Parameter::height], // input_padding_bottom + padding[Layout::Parameter::width], // input_padding_left + weight_nhwc.size(Layout::Filter::height), // kernel_height + weight_nhwc.size(Layout::Filter::width), // kernel_width + stride[Layout::Parameter::height], // subsampling_height + stride[Layout::Parameter::width], // subsampling_width + dilation[Layout::Parameter::height], // dilation_height + dilation[Layout::Parameter::width], // dilation_width + groups, // groups + weight_nhwc.size(Layout::Filter::input), // group_input_channels + weight_nhwc.size(Layout::Filter::output) / groups, // group_output_channels + weight_nhwc.size(Layout::Filter::input) * groups, // input_pixel_stride + weight_nhwc.size(Layout::Filter::output), // output_pixel_stride + weight_nhwc.data_ptr(), // kernel + (bias && bias->defined()) ? bias->data_ptr() : nullptr, // bias + output_min, // output_min + output_max, // output_max + 0u, // flags + &convolution_op); // operator + + TORCH_CHECK( + xnn_status_success == create_status, + "xnn_create_convolution2d_nhwc_f32 failed!"); + + return Context{ + Operator(convolution_op), + weight_nhwc.sizes().vec(), + padding, + stride, + dilation, + }; +} + +// TODO: Decouple and improve error handling and messages. +bool usable(const Tensor& input) { + // Input + return (4 == input.ndimension()) && + (c10::DeviceType::CPU == input.device().type()) && + (kFloat == input.scalar_type()) && + (input.size(Layout::Activation4D::batch) > 0) && + (input.size(Layout::Activation4D::channels) > 0) && + (input.size(Layout::Activation4D::height) > 0) && + (input.size(Layout::Activation4D::width) > 0) && + true; +} + +Tensor run( + const Context& context, + const Tensor& input) { + using namespace internal; + + const Tensor input_nhwc = input.contiguous(MemoryFormat::ChannelsLast); + + TORCH_CHECK( + usable(input_nhwc), + "XNNPACK Convolution not usable! 
" + "Reason: The provided input tensor is either invalid or unsupported by XNNPACK."); + + Tensor output = empty_with_tail_padding( + conv_output_size( + input_nhwc.sizes(), + context.weight_size, + context.padding, + context.stride, + context.dilation), + input_nhwc.options().dtype(), + MemoryFormat::ChannelsLast); + + const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32( + context.convolution_op.get(), // operator + input_nhwc.size(Layout::Activation4D::batch), // batch_size + input_nhwc.size(Layout::Activation4D::height), // input_height + input_nhwc.size(Layout::Activation4D::width), // input_width + input_nhwc.data_ptr(), // input + output.data_ptr(), // output + nullptr); // threadpool + + TORCH_CHECK( + xnn_status_success == setup_status, + "xnn_setup_convolution2d_nhwc_f32 failed!"); + + const xnn_status run_status = xnn_run_operator( + context.convolution_op.get(), // operator + nullptr); // threadpool + + TORCH_INTERNAL_ASSERT( + xnn_status_success == run_status, + "xnn_run_operator failed!"); + + return output.contiguous(input.suggest_memory_format()); +} + +Tensor create_and_run( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const IntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const float output_min, + const float output_max) { + return run( + create( + weight, + bias, + padding, + stride, + dilation, + groups, + output_min, + output_max), + input); +} + +} // namespace +} // namespace convolution2d +} // namespace internal + +bool use_convolution2d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const IntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups) { + return internal::convolution2d::available( + weight, + bias, + padding, + stride, + dilation, + groups, + internal::convolution2d::Context::kMin, + internal::convolution2d::Context::kMax) && + internal::convolution2d::usable(input); +} + +Tensor convolution2d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const IntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups) { + return internal::convolution2d::create_and_run( + input, + weight, + bias, + padding, + stride, + dilation, + groups, + internal::convolution2d::Context::kMin, + internal::convolution2d::Context::kMax); +} + +} // namespace xnnpack + +at::Tensor _conv2d_prepack( + const Tensor& weight, + const Tensor& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t groups, + const c10::optional output_min, + const c10::optional output_max) { + return cpp_custom_type_hack::create( + std::make_unique( + xnnpack::internal::convolution2d::create( + weight, + bias, + padding.vec(), + stride.vec(), + dilation.vec(), + groups, + output_min ? *output_min : xnnpack::internal::convolution2d::Context::kMin, + output_max ? 
+      "Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
+
+  Tensor output = empty_with_tail_padding(
+      conv_output_size(
+          input_nhwc.sizes(),
+          context.weight_size,
+          context.padding,
+          context.stride,
+          context.dilation),
+      input_nhwc.options().dtype(),
+      MemoryFormat::ChannelsLast);
+
+  const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32(
+      context.convolution_op.get(),                   // operator
+      input_nhwc.size(Layout::Activation4D::batch),   // batch_size
+      input_nhwc.size(Layout::Activation4D::height),  // input_height
+      input_nhwc.size(Layout::Activation4D::width),   // input_width
+      input_nhwc.data_ptr<float>(),                   // input
+      output.data_ptr<float>(),                       // output
+      nullptr);                                       // threadpool
+
+  TORCH_CHECK(
+      xnn_status_success == setup_status,
+      "xnn_setup_convolution2d_nhwc_f32 failed!");
+
+  const xnn_status run_status = xnn_run_operator(
+      context.convolution_op.get(),  // operator
+      nullptr);                      // threadpool
+
+  TORCH_INTERNAL_ASSERT(
+      xnn_status_success == run_status,
+      "xnn_run_operator failed!");
+
+  return output.contiguous(input.suggest_memory_format());
+}
+
+Tensor create_and_run(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    const IntArrayRef padding,
+    const IntArrayRef stride,
+    const IntArrayRef dilation,
+    const int64_t groups,
+    const float output_min,
+    const float output_max) {
+  return run(
+      create(
+          weight,
+          bias,
+          padding,
+          stride,
+          dilation,
+          groups,
+          output_min,
+          output_max),
+      input);
+}
+
+} // namespace
+} // namespace convolution2d
+} // namespace internal
+
+bool use_convolution2d(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    const IntArrayRef padding,
+    const IntArrayRef stride,
+    const IntArrayRef dilation,
+    const int64_t groups) {
+  return internal::convolution2d::available(
+             weight,
+             bias,
+             padding,
+             stride,
+             dilation,
+             groups,
+             internal::convolution2d::Context::kMin,
+             internal::convolution2d::Context::kMax) &&
+         internal::convolution2d::usable(input);
+}
+
+Tensor convolution2d(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    const IntArrayRef padding,
+    const IntArrayRef stride,
+    const IntArrayRef dilation,
+    const int64_t groups) {
+  return internal::convolution2d::create_and_run(
+      input,
+      weight,
+      bias,
+      padding,
+      stride,
+      dilation,
+      groups,
+      internal::convolution2d::Context::kMin,
+      internal::convolution2d::Context::kMax);
+}
+
+} // namespace xnnpack
+
+at::Tensor _conv2d_prepack(
+    const Tensor& weight,
+    const Tensor& bias,
+    const IntArrayRef stride,
+    const IntArrayRef padding,
+    const IntArrayRef dilation,
+    const int64_t groups,
+    const c10::optional<double> output_min,
+    const c10::optional<double> output_max) {
+  return cpp_custom_type_hack::create(
+      std::make_unique<xnnpack::internal::convolution2d::Context>(
+          xnnpack::internal::convolution2d::create(
+              weight,
+              bias,
+              padding.vec(),
+              stride.vec(),
+              dilation.vec(),
+              groups,
+              output_min ? *output_min : xnnpack::internal::convolution2d::Context::kMin,
+              output_max ? *output_max : xnnpack::internal::convolution2d::Context::kMax)),
+      weight.options());
+}
+
+at::Tensor _conv2d_packed(
+    const Tensor& packed_weight,
+    const Tensor& input) {
+  return xnnpack::internal::convolution2d::run(
+      cpp_custom_type_hack::cast<xnnpack::internal::convolution2d::Context>(packed_weight),
+      input);
+}
+
+} // namespace native
+} // namespace at
+
+namespace caffe2 {
+
+CAFFE_KNOWN_TYPE(at::native::xnnpack::internal::convolution2d::Context);
+
+} // namespace caffe2
+
+#endif /* USE_XNNPACK */
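
From the caller's side, the intended pattern is a cheap eligibility check followed by either path. A sketch, assuming `use_convolution2d` / `convolution2d` are declared in a header visible to the caller (the declaring header is not part of this diff):

```cpp
#include <cstdint>
#include <vector>

#include <ATen/ATen.h>

// Gating sketch: try the XNNPACK path, fall back to the reference conv2d.
at::Tensor conv2d_maybe_xnnpack(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias) {
  namespace xnn = at::native::xnnpack;

  const std::vector<int64_t> padding{1, 1};
  const std::vector<int64_t> stride{1, 1};
  const std::vector<int64_t> dilation{1, 1};
  const int64_t groups = 1;

  // use_convolution2d() validates weight/bias/params and the input, and
  // also triggers XNNPACK's lazy initialization.
  if (xnn::use_convolution2d(input, weight, bias, padding, stride, dilation, groups)) {
    return xnn::convolution2d(input, weight, bias, padding, stride, dilation, groups);
  }
  return at::conv2d(input, weight, bias, stride, padding, dilation, groups);
}
```
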
Reason: Unknown error!"); + } + } + + return !is_initialized_; +} + +} // namespace + +bool available() { + // Add extra conditions here that should disable mobile CPU impl at runtime in its totality. + return internal::initialize(); +} + +} // namespace internal +} // namespace xnnpack +} // namespace native +} // namespace at + +#endif /* USE_XNNPACK */ diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp new file mode 100644 index 000000000000..fa31a1430ce3 --- /dev/null +++ b/aten/src/ATen/native/xnnpack/Linear.cpp @@ -0,0 +1,223 @@ +#ifdef USE_XNNPACK + +#include +#include +#include + +namespace at { +namespace native { +namespace xnnpack { +namespace internal { +namespace linear { + +struct Context final { + Operator linear_op; + + struct Output final { + int64_t channels; + } output; + + static constexpr float kMin = -std::numeric_limits::infinity(); + static constexpr float kMax = std::numeric_limits::infinity(); +}; + +namespace { + +// Supports NHWC and NCHW FP32 linear operators. + +// TODO: Decouple and improve error handling and messages. +bool available( + const Tensor& weight, + const c10::optional& bias, + const float output_min, + const float output_max) { + // XNNPACK + return xnnpack::internal::available() && + // Weight + (2 == weight.ndimension()) && + (c10::DeviceType::CPU == weight.device().type()) && + (kFloat == weight.scalar_type()) && + // Bias + ((bias && bias->defined()) ? ((1 == bias->ndimension()) && + (c10::DeviceType::CPU == bias->device().type()) && + (kFloat == bias->scalar_type()) && + (weight.size(Layout::Filter::output)) == bias->size(0)) + : true) && + // Output Min / Max + (output_max > output_min) && + true; +} + +Context create( + const Tensor& weight, + const c10::optional& bias, + const float output_min, + const float output_max) { + const Tensor weight_contig = weight.contiguous(); + + TORCH_CHECK( + available( + weight_contig, + bias, + output_min, + output_max), + "XNNPACK Linear not available! " + "Reason: The provided (weight, bias, output_min, output_max) parameters are " + "either invalid individually or their combination is not supported by XNNPACK."); + + xnn_operator_t linear_op{}; + + const xnn_status create_status = xnn_create_fully_connected_nc_f32( + weight_contig.size(Layout::Filter::input), // input_channels + weight_contig.size(Layout::Filter::output), // output_channels + weight_contig.size(Layout::Filter::input), // input_pixel_stride + weight_contig.size(Layout::Filter::output), // output_pixel_stride + weight_contig.data_ptr(), // kernel + (bias && bias->defined()) ? bias->data_ptr() : nullptr, // bias + output_min, // output_min + output_max, // output_max + 0u, // flags + &linear_op); // operator + + TORCH_CHECK( + xnn_status_success == create_status, + "xnn_create_fully_connected_nc_f32 failed!"); + + return Context{ + Operator(linear_op), + { + weight_contig.size(Layout::Filter::output), + } + }; +} + +// TODO: Decouple and improve error handling and messages. +bool usable(const Tensor& input) { + // Input + return (2 <= input.ndimension()) && + (c10::DeviceType::CPU == input.device().type()) && + (kFloat == input.scalar_type()) && + true; +} + +Tensor run( + const Context& context, + const Tensor& input) { + using namespace internal; + + const Tensor& input_contig = input.contiguous(); + + TORCH_CHECK( + usable(input_contig), + "XNNPACK Linear not usable! 
" + "Reason: The provided input tensor is either invalid or unsupported by XNNPACK."); + + const IntArrayRef input_size = input_contig.sizes(); + std::vector output_size(input_size.cbegin(), input_size.cend()); + output_size.back() = context.output.channels; + + Tensor output = empty_with_tail_padding( + output_size, + input_contig.options().dtype(), + input_contig.suggest_memory_format()); + + const xnn_status setup_status = xnn_setup_fully_connected_nc_f32( + context.linear_op.get(), // operator + Layout::ActivationND::batch(input_contig.sizes()), // Batch, + input_contig.data_ptr(), // input + output.data_ptr(), // output + nullptr); // threadpool + + TORCH_CHECK( + xnn_status_success == setup_status, + "xnn_setup_fully_connected_nc_f32 failed!"); + + const xnn_status run_status = xnn_run_operator( + context.linear_op.get(), // operator + nullptr); // threadpool + + TORCH_INTERNAL_ASSERT( + xnn_status_success == run_status, + "xnn_run_operator failed!"); + + return output; +} + +Tensor create_and_run( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const float output_min, + const float output_max) { + return run( + create( + weight, + bias, + output_min, + output_max), + input); +} + +} // namespace +} // namespace linear +} // namespace internal + +bool use_linear( + const Tensor& input, + const Tensor& weight, + const Tensor& bias) { + return internal::linear::available( + weight, + bias, + internal::linear::Context::kMin, + internal::linear::Context::kMax) && + internal::linear::usable(input); +} + +Tensor linear( + const Tensor& input, + const Tensor& weight, + const Tensor& bias) { + return internal::linear::create_and_run( + input, + weight, + bias, + internal::linear::Context::kMin, + internal::linear::Context::kMax); +} + +} // namespace xnnpack + +Tensor _linear_prepack( + const Tensor& weight, + const Tensor& bias, + const c10::optional output_min, + const c10::optional output_max) { + return cpp_custom_type_hack::create( + std::make_unique( + xnnpack::internal::linear::create( + weight, + bias, + output_min ? *output_min : xnnpack::internal::linear::Context::kMin, + output_max ? *output_max : xnnpack::internal::linear::Context::kMax)), + weight.options()); +} + +Tensor _linear_packed( + const Tensor& packed_weight, + const Tensor& input) { + return xnnpack::internal::linear::run( + cpp_custom_type_hack::cast(packed_weight), + input); +} + +} // namespace native +} // namespace at + +namespace caffe2 { + +CAFFE_KNOWN_TYPE(at::native::xnnpack::internal::linear::Context); + +} // namespace caffe2 + +#endif /* USE_XNNPACK */ diff --git a/aten/src/ATen/native/xnnpack/Shim.cpp b/aten/src/ATen/native/xnnpack/Shim.cpp new file mode 100644 index 000000000000..d404f20300cf --- /dev/null +++ b/aten/src/ATen/native/xnnpack/Shim.cpp @@ -0,0 +1,96 @@ +#ifndef USE_XNNPACK + +#include + +namespace at { +namespace native { +namespace xnnpack { +namespace internal { +namespace { + +constexpr const char * const kError = + "Not Implemented! 
Reason: PyTorch not built with XNNPACK support."; + +} // namespace +} // namespace internal + +bool available() { + return false; +} + +bool use_convolution2d( + const Tensor&, + const Tensor&, + const Tensor&, + const IntArrayRef, + const IntArrayRef, + const IntArrayRef, + const int64_t, + const bool) { + return false; +} + +Tensor convolution2d( + const Tensor&, + const Tensor&, + const Tensor&, + const IntArrayRef, + const IntArrayRef, + const IntArrayRef, + const int64_t, + const bool) { + TORCH_CHECK(false, internal::kError); +} + +bool use_linear( + const Tensor&, + const Tensor&, + const Tensor&) { + return false; +} + +Tensor linear( + const Tensor&, + const Tensor&, + const Tensor&) { + TORCH_CHECK(false, internal::kError); +} + +} // namespace xnnpack + +at::Tensor _conv2d_prepack( + const Tensor&, + const Tensor&, + const IntArrayRef, + const IntArrayRef, + const IntArrayRef, + const int64_t, + const c10::optional, + const c10::optional) { + TORCH_CHECK(false, xnnpack::internal::kError); +} + +at::Tensor _conv2d_packed( + const Tensor&, + const Tensor&) { + TORCH_CHECK(false, xnnpack::internal::kError); +} + +Tensor _linear_prepack( + const Tensor&, + const Tensor&, + const c10::optional, + const c10::optional) { + TORCH_CHECK(false, xnnpack::internal::kError); +} + +Tensor _linear_packed( + const Tensor&, + const Tensor&) { + TORCH_CHECK(false, xnnpack::internal::kError); +} + +} // namespace native +} // namespace at + +#endif /* USE_XNNPACK */ diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 082ac0abd587..a42a3697f97a 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -478,37 +479,8 @@ QTensorImpl* get_qtensorimpl(const Tensor& self) { // on a different page out of the process's address space. // Here we define a custom allocator that allocates the extra storage required to keep // this behavior safe. This same allocator can be used for FBGEMM as well. 
diff --git a/aten/src/ATen/native/xnnpack/Shim.cpp b/aten/src/ATen/native/xnnpack/Shim.cpp
new file mode 100644
index 000000000000..d404f20300cf
--- /dev/null
+++ b/aten/src/ATen/native/xnnpack/Shim.cpp
@@ -0,0 +1,96 @@
+#ifndef USE_XNNPACK
+
+#include <ATen/native/xnnpack/Common.h>
+
+namespace at {
+namespace native {
+namespace xnnpack {
+namespace internal {
+namespace {
+
+constexpr const char * const kError =
+    "Not Implemented! Reason: PyTorch not built with XNNPACK support.";
+
+} // namespace
+} // namespace internal
+
+bool available() {
+  return false;
+}
+
+bool use_convolution2d(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    const IntArrayRef,
+    const IntArrayRef,
+    const IntArrayRef,
+    const int64_t,
+    const bool) {
+  return false;
+}
+
+Tensor convolution2d(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    const IntArrayRef,
+    const IntArrayRef,
+    const IntArrayRef,
+    const int64_t,
+    const bool) {
+  TORCH_CHECK(false, internal::kError);
+}
+
+bool use_linear(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&) {
+  return false;
+}
+
+Tensor linear(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&) {
+  TORCH_CHECK(false, internal::kError);
+}
+
+} // namespace xnnpack
+
+at::Tensor _conv2d_prepack(
+    const Tensor&,
+    const Tensor&,
+    const IntArrayRef,
+    const IntArrayRef,
+    const IntArrayRef,
+    const int64_t,
+    const c10::optional<double>,
+    const c10::optional<double>) {
+  TORCH_CHECK(false, xnnpack::internal::kError);
+}
+
+at::Tensor _conv2d_packed(
+    const Tensor&,
+    const Tensor&) {
+  TORCH_CHECK(false, xnnpack::internal::kError);
+}
+
+Tensor _linear_prepack(
+    const Tensor&,
+    const Tensor&,
+    const c10::optional<double>,
+    const c10::optional<double>) {
+  TORCH_CHECK(false, xnnpack::internal::kError);
+}
+
+Tensor _linear_packed(
+    const Tensor&,
+    const Tensor&) {
+  TORCH_CHECK(false, xnnpack::internal::kError);
+}
+
+} // namespace native
+} // namespace at
+
+#endif /* USE_XNNPACK */
diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
index 082ac0abd587..a42a3697f97a 100644
--- a/aten/src/ATen/quantized/Quantizer.cpp
+++ b/aten/src/ATen/quantized/Quantizer.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <ATen/native/utils/Allocator.h>
 #include
 #include
 #include
@@ -478,37 +479,8 @@ QTensorImpl* get_qtensorimpl(const Tensor& self) {
 // on a different page out of the process's address space.
 // Here we define a custom allocator that allocates the extra storage required to keep
 // this behavior safe. This same allocator can be used for FBGEMM as well.
-struct QAllocator final : at::Allocator {
-public:
-  virtual ~QAllocator() override = default;
-
-  virtual at::DataPtr allocate(size_t nbytes) const override {
-    Cast memory{c10::alloc_cpu(kGuard + nbytes)};
-    memory.as_byte_ptr += kGuard;
-    return {
-      memory.as_void_ptr,
-      memory.as_void_ptr,
-      &deleter,
-      at::Device(at::DeviceType::CPU)};
-  }
-
-  virtual at::DeleterFnPtr raw_deleter() const override {
-    return deleter;
-  }
-
-  static void deleter(void * const pointer) {
-    const Cast memory{pointer};
-    c10::free_cpu(memory.as_byte_ptr - kGuard);
-  }
-
- private:
-  static constexpr uint32_t kGuard = 8u;
-  union Cast final {
-    void * const as_void_ptr;
-    uint8_t * as_byte_ptr;
-  };
-};
+using QAllocator = native::GuardingAllocator<8u, 0u>;
 
 #endif
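
The replacement is behavior-preserving: the deleted QAllocator carried an 8-byte pre-guard and no post-guard, which is exactly what the alias's template arguments pin down. A restatement:

```cpp
#include <ATen/native/utils/Allocator.h>

// The old QAllocator's kGuard = 8u pre-offset maps onto the new template
// parameters, so the allocation layout and deleter arithmetic are unchanged.
using QAllocator =
    at::native::GuardingAllocator</*PreGuardBytes=*/8u, /*PostGuardBytes=*/0u>;
```
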
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 015640248728..38b38f84f437 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -151,7 +151,6 @@
 else()
   message(FATAL_ERROR "Unrecognized BLAS option: " ${BLAS})
 endif()
-
 if (NOT INTERN_BUILD_MOBILE)
   set(AT_MKL_ENABLED 0)
   set(AT_MKL_MT 0)
@@ -180,85 +179,63 @@
   endif()
 endif()
 
-# Directory where NNPACK and cpuinfo will download and build all dependencies
-set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs
-  CACHE PATH "Confu-style dependencies source directory")
-set(CONFU_DEPENDENCIES_BINARY_DIR ${PROJECT_BINARY_DIR}/confu-deps
-  CACHE PATH "Confu-style dependencies binary directory")
+# ---[ Dependencies
+# NNPACK and family (QNNPACK, PYTORCH_QNNPACK, and XNNPACK) can download and
+# compile their dependencies in isolation as part of their build. These dependencies
+# are then linked statically with PyTorch. To avoid the possibility of a version
+# mismatch between these shared dependencies, explicitly declare to these
+# libraries that we intend to use the exact same source dependencies for all of them.
 
-# ---[ Eigen BLAS for Mobile
-if(INTERN_BUILD_MOBILE AND INTERN_USE_EIGEN_BLAS)
-  set(USE_BLAS 1)
-  include(${CMAKE_CURRENT_LIST_DIR}/External/EigenBLAS.cmake)
-  list(APPEND Caffe2_DEPENDENCY_LIBS eigen_blas)
-endif()
+if (USE_NNPACK OR USE_QNNPACK OR USE_PYTORCH_QNNPACK OR USE_XNNPACK)
+  set(DISABLE_NNPACK_AND_FAMILY OFF)
 
-# ---[ pthreadpool
-# QNNPACK and NNPACK both depend on pthreadpool, but when building with libtorch
-# they should use the pthreadpool implementation under caffe2/utils/threadpool
-# instead of the default implementation. To avoid confusion, add pthreadpool
-# subdirectory explicitly with EXCLUDE_FROM_ALL property prior to QNNPACK/NNPACK
-# does so, which will prevent it from installing the default pthreadpool library.
-if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK))
-  if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
-    set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
-    set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
-  endif()
-
-  IF(NOT TARGET pthreadpool)
-    SET(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
-    SET(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
-    ADD_SUBDIRECTORY(
-      "${PTHREADPOOL_SOURCE_DIR}"
-      "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool"
-      EXCLUDE_FROM_ALL)
-  ENDIF()
-endif()
+  # Sanity checks - Can we actually build NNPACK and family given the configuration provided?
+  # Disable them and warn the user if not.
 
-# ---[ QNNPACK
-if(USE_QNNPACK)
   if (IOS)
     list(LENGTH IOS_ARCH IOS_ARCH_COUNT)
     if (IOS_ARCH_COUNT GREATER 1)
       message(WARNING
-        "Multi-architecture (${IOS_ARCH}) builds are not supported in QNNPACK. "
+        "Multi-architecture (${IOS_ARCH}) builds are not supported in {Q/X}NNPACK. "
         "Specify a single architecture in IOS_ARCH and re-configure, or "
-        "turn this warning off by USE_QNNPACK=OFF.")
-      set(USE_QNNPACK OFF)
+        "turn this warning off by USE_{Q/X}NNPACK=OFF.")
+      set(DISABLE_NNPACK_AND_FAMILY ON)
     endif()
     if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$")
       message(WARNING
-        "Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. "
+        "Target architecture \"${IOS_ARCH}\" is not supported in {Q/X}NNPACK. "
         "Supported architectures are x86, x86-64, ARM, and ARM64. "
-        "Turn this warning off by USE_QNNPACK=OFF.")
-      set(USE_QNNPACK OFF)
+        "Turn this warning off by USE_{Q/X}NNPACK=OFF.")
+      set(DISABLE_NNPACK_AND_FAMILY ON)
     endif()
   else()
     if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$"))
       message(WARNING
-        "Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. "
+        "Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in {Q/X}NNPACK. "
         "Supported platforms are Android, iOS, Linux, and macOS. "
-        "Turn this warning off by USE_QNNPACK=OFF.")
-      set(USE_QNNPACK OFF)
+        "Turn this warning off by USE_{Q/X}NNPACK=OFF.")
+      set(DISABLE_NNPACK_AND_FAMILY ON)
    endif()
     if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$"))
       message(WARNING
-        "Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. "
+        "Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in {Q/X}NNPACK. "
        "Supported architectures are x86, x86-64, ARM, and ARM64. "
-        "Turn this warning off by USE_QNNPACK=OFF.")
-      set(USE_QNNPACK OFF)
+        "Turn this warning off by USE_{Q/X}NNPACK=OFF.")
+      set(DISABLE_NNPACK_AND_FAMILY ON)
     endif()
   endif()
-  if (USE_QNNPACK)
+
+  if (DISABLE_NNPACK_AND_FAMILY)
+    set(USE_NNPACK OFF)
+    set(USE_QNNPACK OFF)
+    set(USE_PYTORCH_QNNPACK OFF)
+    set(USE_XNNPACK OFF)
+  else()
     set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
 
-    # Directories for QNNPACK dependencies submoduled in Caffe2
     if (NOT DEFINED CPUINFO_SOURCE_DIR)
       set(CPUINFO_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/cpuinfo" CACHE STRING "cpuinfo source directory")
     endif()
-    if (NOT DEFINED QNNPACK_SOURCE_DIR)
-      set(QNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/QNNPACK" CACHE STRING "QNNPACK source directory")
-    endif()
     if (NOT DEFINED FP16_SOURCE_DIR)
       set(FP16_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/FP16" CACHE STRING "FP16 source directory")
     endif()
@@ -272,26 +249,70 @@
       set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
     endif()
 
-    if(NOT TARGET qnnpack)
-      set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
-      set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
-      set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
-      set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
-      set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
-      set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
-      set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
-      add_subdirectory(
-        "${QNNPACK_SOURCE_DIR}"
-        "${CONFU_DEPENDENCIES_BINARY_DIR}/QNNPACK")
-      # We build static versions of QNNPACK and pthreadpool but link
-      # them into a shared library for Caffe2, so they need PIC.
-      set_property(TARGET qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
-      set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
-      set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
-    endif()
+    set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "")
+    set(CPUINFO_LOG_LEVEL "error" CACHE STRING "")
+    set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "")
+  endif()
+endif()
+
+set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs
+  CACHE PATH "Confu-style dependencies source directory")
+set(CONFU_DEPENDENCIES_BINARY_DIR ${PROJECT_BINARY_DIR}/confu-deps
+  CACHE PATH "Confu-style dependencies binary directory")
+
+# ---[ Eigen BLAS for Mobile
+if(INTERN_BUILD_MOBILE AND INTERN_USE_EIGEN_BLAS)
+  set(USE_BLAS 1)
+  include(${CMAKE_CURRENT_LIST_DIR}/External/EigenBLAS.cmake)
+  list(APPEND Caffe2_DEPENDENCY_LIBS eigen_blas)
+endif()
+
+# ---[ pthreadpool
+# QNNPACK and NNPACK both depend on pthreadpool, but when building with libtorch
+# they should use the pthreadpool implementation under caffe2/utils/threadpool
+# instead of the default implementation. To avoid confusion, add the pthreadpool
+# subdirectory explicitly with the EXCLUDE_FROM_ALL property before QNNPACK/NNPACK
+# do so, which will prevent it from installing the default pthreadpool library.
+if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK OR USE_XNNPACK))
+  if(NOT DEFINED PTHREADPOOL_SOURCE_DIR)
+    set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
+    set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory")
+  endif()
+
+  IF(NOT TARGET pthreadpool)
+    SET(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "")
+    SET(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "")
+    ADD_SUBDIRECTORY(
+      "${PTHREADPOOL_SOURCE_DIR}"
+      "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool"
+      EXCLUDE_FROM_ALL)
+  ENDIF()
+endif()
 
-    list(APPEND Caffe2_DEPENDENCY_LIBS qnnpack)
+# ---[ QNNPACK
+if(USE_QNNPACK)
+  set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
+
+  if (NOT DEFINED QNNPACK_SOURCE_DIR)
+    set(QNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/QNNPACK" CACHE STRING "QNNPACK source directory")
   endif()
+
+  if(NOT TARGET qnnpack)
+    set(QNNPACK_BUILD_TESTS OFF CACHE BOOL "")
+    set(QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
+    set(QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
+    set(QNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
+    add_subdirectory(
+      "${QNNPACK_SOURCE_DIR}"
+      "${CONFU_DEPENDENCIES_BINARY_DIR}/QNNPACK")
+    # We build static versions of QNNPACK and pthreadpool but link
+    # them into a shared library for Caffe2, so they need PIC.
+    set_property(TARGET qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON)
+    set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON)
+    set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
+  endif()
+
+  list(APPEND Caffe2_DEPENDENCY_LIBS qnnpack)
 endif()
" - "Specify a single architecture in IOS_ARCH and re-configure, or " - "turn this warning off by USE_PYTORCH_QNNPACK=OFF.") - set(USE_PYTORCH_QNNPACK OFF) + if (NOT DEFINED PYTORCH_QNNPACK_SOURCE_DIR) + set(PYTORCH_QNNPACK_SOURCE_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/qnnpack" CACHE STRING "QNNPACK source directory") endif() - if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$") - message(WARNING - "Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. " - "Supported architectures are x86, x86-64, ARM, and ARM64. " - "Turn this warning off by USE_PYTORCH_QNNPACK=OFF.") - set(USE_PYTORCH_QNNPACK OFF) - endif() - else() - if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$")) - message(WARNING - "Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. " - "Supported platforms are Android, iOS, Linux, and macOS. " - "Turn this warning off by USE_PYTORCH_QNNPACK=OFF.") - set(USE_PYTORCH_QNNPACK OFF) - endif() - if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$")) - message(WARNING - "Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. " - "Supported architectures are x86, x86-64, ARM, and ARM64. " - "Turn this warning off by USE_PYTORCH_QNNPACK=OFF.") - set(USE_PYTORCH_QNNPACK OFF) - endif() - endif() - if (USE_PYTORCH_QNNPACK) - if (NOT DEFINED PYTORCH_QNNPACK_SOURCE_DIR) - set(PYTORCH_QNNPACK_SOURCE_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/qnnpack" CACHE STRING "QNNPACK source directory") - endif() - if(NOT TARGET pytorch_qnnpack) - set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "") - set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") - set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") - set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "") - set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "") - set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "") - set(CPUINFO_LOG_LEVEL "error" CACHE STRING "") - add_subdirectory( - "${PYTORCH_QNNPACK_SOURCE_DIR}" - "${CONFU_DEPENDENCIES_BINARY_DIR}/pytorch_qnnpack") - # We build static versions of QNNPACK and pthreadpool but link - # them into a shared library for Caffe2, so they need PIC. - set_property(TARGET pytorch_qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) - endif() + if(NOT TARGET pytorch_qnnpack) + set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "") + set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") + set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") + set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "") + add_subdirectory( + "${PYTORCH_QNNPACK_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/pytorch_qnnpack") + # We build static versions of QNNPACK and pthreadpool but link + # them into a shared library for Caffe2, so they need PIC. 
 
 # ---[ NNPACK
@@ -379,6 +363,33 @@ if(USE_NNPACK)
   endif()
 endif()
 
+# ---[ XNNPACK
+if(USE_XNNPACK)
+  if (NOT DEFINED XNNPACK_SOURCE_DIR)
+    set(XNNPACK_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/XNNPACK" CACHE STRING "XNNPACK source directory")
+  endif()
+
+  if (NOT DEFINED XNNPACK_INCLUDE_DIR)
+    set(XNNPACK_INCLUDE_DIR "${XNNPACK_SOURCE_DIR}/include" CACHE STRING "XNNPACK include directory")
+  endif()
+
+  if(NOT TARGET XNNPACK)
+    set(XNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "")
+    set(XNNPACK_LIBRARY_TYPE "static" CACHE STRING "")
+    set(XNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "")
+    set(XNNPACK_BUILD_TESTS OFF CACHE BOOL "")
+
+    add_subdirectory(
+      "${XNNPACK_SOURCE_DIR}"
+      "${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK")
+
+    set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON)
+  endif()
+
+  include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})
+  list(APPEND Caffe2_DEPENDENCY_LIBS XNNPACK)
+endif()
+
 # ---[ Caffe2 uses cpuinfo library in the thread pool
 if (NOT TARGET cpuinfo)
   if (NOT DEFINED CPUINFO_SOURCE_DIR)
diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in
index c11d54625df5..a77c9ac06b17 100644
--- a/cmake/TorchConfig.cmake.in
+++ b/cmake/TorchConfig.cmake.in
@@ -64,6 +64,11 @@ if (NOT @BUILD_SHARED_LIBS@)
     list(APPEND TORCH_LIBRARIES ${PYTORCH_QNNPACK_LIBRARY})
   endif()
 
+  if (@USE_XNNPACK@)
+    find_library(XNNPACK_LIBRARY XNNPACK PATHS "${TORCH_INSTALL_PREFIX}/lib")
+    list(APPEND TORCH_LIBRARIES ${XNNPACK_LIBRARY})
+  endif()
+
   if (@INTERN_USE_EIGEN_BLAS@)
     find_library(EIGEN_BLAS_LIBRARY eigen_blas PATHS "${TORCH_INSTALL_PREFIX}/lib")
     list(APPEND TORCH_LIBRARIES ${EIGEN_BLAS_LIBRARY})
diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb
index 8faae2da1f50..ab585b1b9351 100644
--- a/scripts/xcode_build.rb
+++ b/scripts/xcode_build.rb
@@ -2,7 +2,7 @@
 require 'xcodeproj'
 
 options = {}
-option_parser = OptionParser.new do |opts| 
+option_parser = OptionParser.new do |opts|
   opts.banner = 'Tools for building PyTorch iOS framework on MacOS'
   opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value|
     options[:install] = value
@@ -23,11 +23,11 @@
 puts options.inspect
 
 install_path = File.expand_path(options[:install])
-if not Dir.exist? (install_path) 
+if not Dir.exist? (install_path)
   raise "path don't exist:#{install_path}!"
 end
 xcodeproj_path = File.expand_path(options[:xcodeproj])
-if not File.exist? (xcodeproj_path) 
+if not File.exist? (xcodeproj_path)
   raise "path don't exist:#{xcodeproj_path}!"
 end
@@ -51,8 +51,8 @@
 # link static libraries
 target.frameworks_build_phases.clear
-libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
-for lib in libs do 
+libs = ['libc10.a', 'libclog.a', 'libnnpack.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
+for lib in libs do
   path = "#{install_path}/lib/#{lib}"
   if File.exist?(path)
     libref = project.frameworks_group.new_file(path)
@@ -68,12 +68,12 @@
   sdk = 'iphoneos'
 else
   raise "unsupported platform #{options[:platform]}"
-end 
+end
 profile = options[:profile]
-if not profile 
+if not profile
   raise "no provisioning profile found!"
-end 
+end
 
 # run xcodebuild
 exec "xcodebuild clean build -project #{xcodeproj_path} -target #{target.name} -sdk #{sdk} -configuration Release PROVISIONING_PROFILE_SPECIFIER=#{profile}"
diff --git a/third_party/XNNPACK b/third_party/XNNPACK
new file mode 160000
index 000000000000..fa611cc5c241
--- /dev/null
+++ b/third_party/XNNPACK
@@ -0,0 +1 @@
+Subproject commit fa611cc5c2415b330282075167ce5580c620556d
diff --git a/third_party/cpuinfo b/third_party/cpuinfo
index 89fe1695edf9..0e6bde92b343 160000
--- a/third_party/cpuinfo
+++ b/third_party/cpuinfo
@@ -1 +1 @@
-Subproject commit 89fe1695edf9ee14c22f815f24bac45577a4f135
+Subproject commit 0e6bde92b343c5fbcfe34ecd41abf9515d54b4a7
diff --git a/third_party/psimd b/third_party/psimd
index 90a938f30ba4..10b4ffc6ea9e 160000
--- a/third_party/psimd
+++ b/third_party/psimd
@@ -1 +1 @@
-Subproject commit 90a938f30ba414ada2f4b00674ee9631d7d85e19
+Subproject commit 10b4ffc6ea9e2e11668f86969586f88bc82aaefa
diff --git a/third_party/pthreadpool b/third_party/pthreadpool
index 13da0b4c21d1..d465747660ec 160000
--- a/third_party/pthreadpool
+++ b/third_party/pthreadpool
@@ -1 +1 @@
-Subproject commit 13da0b4c21d17f94150713366420baaf1b5a46f4
+Subproject commit d465747660ecf9ebbaddf8c3db37e4a13d0c9103