diff --git a/kernels/portable/cpu/util/slice_util.cpp b/kernels/portable/cpu/util/slice_util.cpp index 05e2f7d8289..659e3e46659 100644 --- a/kernels/portable/cpu/util/slice_util.cpp +++ b/kernels/portable/cpu/util/slice_util.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace torch { @@ -202,12 +203,45 @@ void compute_slice( InvalidArgument, /* void */, "out.nbytes() is smaller than the expected slice size."); - for (const auto i : c10::irange(leading_dims)) { - const char* src = input_data + (i * dim_length + start) * length_per_step; - for ([[maybe_unused]] const auto j : c10::irange(length)) { - memcpy(dest, src, length_per_step); - src += step * length_per_step; - dest += length_per_step; + // Thresholds for enabling multithreading: + // - Minimum number of leading dimensions: 8 + // - Minimum total elements to copy: 32768 (GRAIN_SIZE) + constexpr int64_t MIN_LEADING_DIMS_FOR_MT = 8; + constexpr int64_t MIN_ELEMENTS_FOR_MT = + executorch::extension::internal::GRAIN_SIZE; + + const int64_t total_elements = leading_dims * length * trailing_dims; + const bool use_multithreading = leading_dims >= MIN_LEADING_DIMS_FOR_MT && + total_elements >= MIN_ELEMENTS_FOR_MT; + + if (use_multithreading) { + // Use parallel_for to distribute work across leading dimensions + // Calculate grain size based on number of elements per leading dimension + const int64_t elements_per_leading_dim = length * trailing_dims; + const int64_t grain_size = MIN_LEADING_DIMS_FOR_MT; + + executorch::extension::parallel_for( + 0, leading_dims, grain_size, [&](const auto begin, const auto end) { + for (const auto i : c10::irange(begin, end)) { + const char* src = + input_data + (i * dim_length + start) * length_per_step; + char* local_dest = dest + i * length * length_per_step; + for ([[maybe_unused]] const auto j : c10::irange(length)) { + memcpy(local_dest, src, length_per_step); + src += step * length_per_step; + local_dest += length_per_step; + } + } + }); + } else { + // Single-threaded path for small workloads + for (const auto i : c10::irange(leading_dims)) { + const char* src = input_data + (i * dim_length + start) * length_per_step; + for ([[maybe_unused]] const auto j : c10::irange(length)) { + memcpy(dest, src, length_per_step); + src += step * length_per_step; + dest += length_per_step; + } } } } diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 84d0712c033..402a0934ef5 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -292,6 +292,7 @@ def define_common_targets(): exported_headers = ["slice_util.h"], deps = [ "//executorch/runtime/kernel:kernel_includes", + "//executorch/extension/threadpool:threadpool", ], visibility = ["//executorch/kernels/portable/cpu/..."], )