6 changes: 4 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
@@ -158,7 +158,8 @@ bool can_use_choose_qparams_per_row(
void choose_qparams_affine_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
-int arg_idx = 0;
+size_t arg_idx = 0;
+size_t last_arg_idx = args.size() - 1;
const ValueRef input = args[arg_idx++];
const ValueRef mapping_type = args[arg_idx++];
(void)mapping_type;
@@ -170,7 +171,8 @@ void choose_qparams_affine_impl(
(void)eps;
const ValueRef scale_dtype = args[arg_idx++];
const ValueRef zero_point_dtype = args[arg_idx++];
-const ValueRef out_tuple_ref = args[arg_idx++];
+
+const ValueRef out_tuple_ref = args[last_arg_idx];

// Suppress unused variable warnings
(void)target_dtype;
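The hunks in this file switch the argument cursor to size_t and take the output ValueRef from the end of the argument list rather than from the running index; the same pattern is applied in QuantizeDequantize.cpp below. A minimal standalone sketch of that indexing pattern follows. It is not the ExecuTorch implementation: ValueRef is stubbed as a plain int32_t handle here, and the helper name parse_args_sketch is hypothetical.

```cpp
// Minimal sketch (not the ExecuTorch implementation) of the indexing pattern in the
// hunks above: leading args are read front-to-back, but the output ValueRef is taken
// from the back of the list so extra serialized args in between do not shift it.
// ValueRef is stubbed as int32_t here; in the real backend it comes from the Vulkan
// compute graph runtime.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using ValueRef = int32_t; // stand-in handle type for this sketch only

void parse_args_sketch(const std::vector<ValueRef>& args) {
  size_t arg_idx = 0;
  const size_t last_arg_idx = args.size() - 1;

  const ValueRef input = args[arg_idx++];
  const ValueRef scale = args[arg_idx++];
  const ValueRef zero_point = args[arg_idx++];

  // Args between zero_point and the output (dtype hints, etc.) are skipped here;
  // the output is always the last entry regardless of how many of them are present.
  const ValueRef output = args[last_arg_idx];

  std::cout << "input=" << input << " scale=" << scale
            << " zero_point=" << zero_point << " output=" << output << '\n';
}

int main() {
  parse_args_sketch({10, 11, 12, 99});          // no optional args in between
  parse_args_sketch({10, 11, 12, 13, 14, 99});  // extra args present; output is still 99
  return 0;
}
```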
10 changes: 6 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp
@@ -369,7 +369,8 @@ void add_unpack_4w4c_and_dequantize_node(
void quantize_per_tensor_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
-int32_t arg_idx = 0;
+size_t arg_idx = 0;
+size_t last_arg_idx = args.size() - 1;
const ValueRef fp_input = args[arg_idx++];
const ValueRef scale = args[arg_idx++];
const ValueRef zero_point = args[arg_idx++];
@@ -380,7 +381,7 @@ void quantize_per_tensor_impl(
const ValueRef dtype = args[arg_idx++];
(void)dtype;

-const ValueRef int8_output = args[arg_idx++];
+const ValueRef int8_output = args[last_arg_idx];

VK_CHECK_COND(
graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C);
@@ -392,7 +393,8 @@ void dequantize_per_tensor_impl(
void dequantize_per_tensor_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
-int32_t arg_idx = 0;
+size_t arg_idx = 0;
+size_t last_arg_idx = args.size() - 1;
Review comment (Contributor):
For the ones in between, do you just rely on default behavior? What if they are serialized with values != default? Shouldn't you error out?

Reply from @SS-JIA (Contributor, Author), Dec 11, 2025:
Yeah, that's a good point. I have a planned update to improve arg checking for quantized_decomposed ops, since there are currently a lot of unsupported input cases which are not accounted for; I will include this in that update.

The primary purpose of this PR as-is is to recover a currently broken CI signal, so I would prefer to keep it as simple as possible. In practice, not validating the args should be OK (for now) since the quantized_decomposed ops are inserted by a quantization workflow, and Vulkan doesn't really work with non-supported quant workflows anyway 😛

const ValueRef int8_input = args[arg_idx++];
const ValueRef scale = args[arg_idx++];
const ValueRef zero_point = args[arg_idx++];
@@ -405,7 +407,7 @@ void dequantize_per_tensor_impl(
const ValueRef output_dtype = args[arg_idx++];
(void)output_dtype;

-const ValueRef fp_output = args[arg_idx++];
+const ValueRef fp_output = args[last_arg_idx];

VK_CHECK_COND(
graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C);
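Following up on the review thread above, where validation of the args sitting between the leading inputs and the trailing output is deferred to a later PR, the sketch below shows what such a check could look like. This is not code from this PR: the check_default_args helper, the arg positions, and the expected default value are all hypothetical.

```cpp
// Hypothetical sketch (not from this PR) of the check suggested in the review thread:
// error out if any arg between the leading inputs and the trailing output carries a
// non-default value. ValueRef is stubbed as a plain int32_t value here; in the real
// Vulkan graph runtime the check would need to read the referenced value out of the
// ComputeGraph before comparing it.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using ValueRef = int32_t; // stand-in for this sketch only

void check_default_args(
    const std::vector<ValueRef>& args,
    size_t first,            // index of the first in-between arg
    size_t last,             // one past the last in-between arg (i.e. index of the output)
    ValueRef expected_default) {
  for (size_t i = first; i < last; ++i) {
    if (args[i] != expected_default) {
      throw std::runtime_error(
          "Unsupported non-default arg at index " + std::to_string(i));
    }
  }
}

int main() {
  const std::vector<ValueRef> args = {10, 11, 12, 0, 0, 99};
  try {
    // Args 3..4 sit between the leading inputs and the output; require the default 0.
    check_default_args(args, 3, args.size() - 1, 0);
    std::cout << "in-between args are all default\n";
  } catch (const std::exception& e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}
```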