Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
Expand Down Expand Up @@ -296,6 +297,41 @@ Conv2dMethod get_conv2d_method(
return Conv2dMethod::SlidingWindow;
}

// Decide whether a SlidingWindow conv2d should be computed via the
// im2col + GEMM path (conv2d_gemm_impl) instead of the direct convolution
// shader. Across 26 configs on Mali-G715 (buffer path) and Adreno SM8650
// (texture path): FP32 cases were numerically verified against the reference;
// FP16 cases were routing/dispatch-validated only (the reference is float-only
// for the large shapes, so FP16 outputs were not numerically checked).
//
// Only called for SlidingWindow conv2d (1x1 is routed to conv2d_pw and
// Depthwise/Transposed are handled before the call site).
//
// Preconditions (fall back to direct conv if any fail — the im2col path is
// either not applicable or not beneficial):
// - groups == 1
// - dilation == 1 (all dims)
//
// Selection rule: always use im2col on Mali; on any other device use it only
// once the output channel count reaches kIm2colMinCOut, i.e. is large enough
// to amortize the fixed ~N*K_total im2col gather cost.
constexpr int64_t kIm2colMinCOut = 128;

// Returns true when a SlidingWindow conv2d should be dispatched through the
// im2col + GEMM implementation rather than the direct convolution shader.
//
// The im2col path is rejected outright when groups_val != 1 or when either
// dilation component differs from 1. Otherwise it is selected if the device
// is Mali, or if the first dimension of the weight tensor (treated here as
// the output channel count) is at least kIm2colMinCOut.
bool should_use_conv2d_im2col(
    ComputeGraph& graph,
    const ValueRef weight_data,
    const int64_t groups_val,
    const Kernel2dParams& kernel_params) {
  const bool eligible = groups_val == 1 && kernel_params.dilation[0] == 1 &&
      kernel_params.dilation[1] == 1;
  if (!eligible) {
    return false;
  }

  // The weight sizes are queried before the device check so that the lookup
  // happens on every eligible call, regardless of device.
  const int64_t out_channels = graph.sizes_of(weight_data).at(0);
  if (graph.device_is_mali()) {
    return true;
  }
  return out_channels >= kIm2colMinCOut;
}

utils::uvec3 create_conv2d_global_wg_size(
ComputeGraph& graph,
const Conv2dMethod method,
Expand Down Expand Up @@ -425,7 +461,8 @@ void add_conv2d_node(
const ValueRef out_min,
const ValueRef out_max,
const ValueRef out,
const bool clamp_out) {
const bool clamp_out,
const bool force_direct) {
const bool transposed_val = graph.get_bool(transposed);

float out_min_val = 0.0f;
Expand Down Expand Up @@ -473,6 +510,37 @@ void add_conv2d_node(
out_max_val);
}

const Kernel2dParams kernel_params = create_kernel2d_params(
graph,
weight_data,
/*kernel_size_only = */ false,
stride,
padding,
dilation);

// SlidingWindow conv2d: route to the im2col + GEMM path when the heuristic
// indicates it is beneficial, falling back to the direct convolution shader
// otherwise. `force_direct` bypasses the heuristic entirely and forces the
// direct path (used by tests to exercise the direct shader regardless of
// device); the default (false) reproduces the production routing exactly.
const bool use_im2col = !force_direct &&
method == Conv2dMethod::SlidingWindow &&
should_use_conv2d_im2col(graph, weight_data, groups_val, kernel_params);
if (use_im2col) {
return conv2d_gemm_impl(
graph,
in,
weight_data,
bias,
stride,
padding,
dilation,
out,
clamp_out,
out_min_val,
out_max_val);
}

ValueRef arg_weight = prepack_weights(graph, weight_data, method);
ValueRef arg_bias = prepack_biases(
graph,
Expand All @@ -489,13 +557,6 @@ void add_conv2d_node(

check_conv_args(graph, in, out);

Kernel2dParams kernel_params = create_kernel2d_params(
graph,
weight_data,
/*kernel_size_only = */ false,
stride,
padding,
dilation);
Conv2dParams extra_params =
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);

Expand Down
22 changes: 22 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,26 @@ void resize_conv2d_node(
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args);

// Adds a conv2d node to `graph`. Exposed in this header so that callers
// outside Convolution.cpp (e.g. test code) can invoke it directly instead of
// going through the registered operator.
//
// `force_direct` overrides the im2col-vs-direct routing heuristic: when true,
// a SlidingWindow conv2d always takes the direct sliding-window path,
// bypassing should_use_conv2d_im2col(). The default (false) preserves the
// production routing exactly. Pointwise / Depthwise / Transposed methods are
// unaffected by this flag.
//
// `force_direct` is a trailing parameter with a default value, so call sites
// that pass only the arguments up to `clamp_out` keep compiling and keep the
// heuristic-based routing.
void add_conv2d_node(
ComputeGraph& graph,
const ValueRef in,
const ValueRef weight_data,
const ValueRef bias,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation,
const ValueRef transposed,
const ValueRef output_padding,
const ValueRef groups,
const ValueRef out_min,
const ValueRef out_max,
const ValueRef out,
const bool clamp_out,
const bool force_direct = false);

} // namespace vkcompute
31 changes: 30 additions & 1 deletion backends/vulkan/test/custom_ops/impl/TestConv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.h>

#include <optional>

Expand All @@ -29,7 +30,10 @@ void test_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// args[10] = output [N, C_out, H_out, W_out]
//
// impl_selector grammar:
// "" -> aten.convolution.default (direct sliding-window)
// "" -> aten.convolution.default (heuristic-routed:
// should_use_conv2d_im2col() picks direct vs im2col)
// "direct" -> add_conv2d_node(force_direct=true): forces the direct
// sliding-window path, bypassing the routing heuristic
// "im2col" -> et_vk.conv2d_gemm.default, auto im2col storage
// "im2col_buffer"-> im2col/GEMM, force buffer im2col intermediate
// "im2col_tex2d" -> im2col/GEMM, force texture2d im2col intermediate
Expand Down Expand Up @@ -88,6 +92,31 @@ void test_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
graph.add_scalar_list<int64_t>(std::vector<int64_t>{0, 0});
ValueRef groups = graph.add_scalar<int64_t>(1);

// The "direct" selector must reach the exact direct sliding-window dispatch
// the heuristic would otherwise pick. The registered op can only route via
// the heuristic, so call add_conv2d_node directly with force_direct=true to
// bypass it (mirroring how the forced-storage variants call
// conv2d_gemm_impl).
if (impl_selector == "direct") {
add_conv2d_node(
graph,
input,
weight,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
/*out_min=*/kDummyValueRef,
/*out_max=*/kDummyValueRef,
out,
/*clamp_out=*/false,
/*force_direct=*/true);
return;
}

const std::string target_op = (impl_selector == "im2col")
? "et_vk.conv2d_gemm.default"
: "aten.convolution.default";
Expand Down
56 changes: 52 additions & 4 deletions backends/vulkan/test/custom_ops/test_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,49 @@ static std::vector<TestCase> generate_conv2d_test_cases() {
false},
};

// Two implementation variants: direct sliding-window (default) and im2col.
const std::vector<std::string> impls = {"", "im2col"};
// Boundary pair straddling the should_use_conv2d_im2col() c_out >= 128
// routing threshold. Spatial dims are tiny (8x8) so the FP32 float reference
// stays cheap, but c_out = 64 / 128 are both >= kRefDimSizeLimit, so these
// get the PERF label. FP32 PERF cases are still numerically VERIFIED (the
// reference's invalid_argument throw that skips the check only fires for
// half), so both implementations are cross-checked against the float
// reference at the boundary. Run all three impls: at c_out = 64 the heuristic
// ("") picks direct on Adreno / im2col on Mali; at c_out = 128 it picks
// im2col on both — and "direct"/"im2col" force each path regardless, proving
// the two implementations agree at the boundary on either device.
std::vector<Conv2dTestConfig> boundary_configs = {
// c_out = 64 (< 128): below the threshold
{InputDims(1, 16, 8, 8),
64,
KernelSize(3, 3),
Stride(1, 1),
Padding(1, 1),
Dilation(1, 1),
false},
// c_out = 128 (== 128): at/above the threshold
{InputDims(1, 16, 8, 8),
128,
KernelSize(3, 3),
Stride(1, 1),
Padding(1, 1),
Dilation(1, 1),
false},
};

// Implementation variants exercised for every small ACCU shape:
// "" -> heuristic-routed (should_use_conv2d_im2col picks direct on
// Adreno for small c_out, im2col on Mali)
// "im2col" -> forced im2col/GEMM path
// "direct" -> forced direct sliding-window path (force_direct=true)
// Including "direct" guarantees the direct shader gets reference-checked on
// BOTH devices — without it, Mali would always route "" to im2col and never
// exercise the direct path.
const std::vector<std::string> impls = {"", "im2col", "direct"};
// Forced-storage im2col variants for the per-variant ACCU coverage.
const std::vector<std::string> forced_storage_impls = {
"im2col_buffer", "im2col_tex2d", "im2col_tex3d"};

// Generate accuracy test cases for both impls and both dtypes. FP16 small
// Generate accuracy test cases for all impls and both dtypes. FP16 small
// shapes get a real reference check (gated in conv2d_reference_impl); we run
// both dtypes so we catch correctness regressions in either path. Large-K
// half stays timing-only via the reference's PERF-shape throw.
Expand Down Expand Up @@ -515,7 +551,19 @@ static std::vector<TestCase> generate_conv2d_test_cases() {
}
}

// Generate performance test cases (float and half) for both impls.
// Generate the c_out boundary pair (FP32 only) through all three impls.
// FP32 PERF cases are reference-VERIFIED, so the direct and im2col paths are
// both cross-checked against the float reference at the routing threshold.
for (const auto& config : boundary_configs) {
for (auto st : storage_types) {
for (const auto& impl : impls) {
test_cases.push_back(
create_conv2d_test_case(config, vkapi::kFloat, st, layout, impl));
}
}
}

// Generate performance test cases (float and half) for all impls.
for (const auto& config : perf_configs) {
std::vector<vkapi::ScalarType> dtypes = {vkapi::kFloat, vkapi::kHalf};
for (auto dtype : dtypes) {
Expand Down
Loading