Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 23 additions & 36 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ def register_ephemeral_op(features: OpFeatures):

@update_features(
[
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
Expand Down Expand Up @@ -276,14 +276,32 @@ def register_quantization_op(features: OpFeatures):
[
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.dequantize_affine.default,
]
)
def register_affine_quantization_op(features: OpFeatures):
    """Populate the feature flags used for the affine quantization ops.

    The given ``features`` object is updated in place and then returned.
    """
    # Texture implementation: width-packed only, no axis map.
    width_packed_texture = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.texture_impl = width_packed_texture

    # A buffer implementation and a resize function are both flagged as available.
    features.buffer_impl = True
    features.resize_fn = True

    # Preferred storage type and memory layout.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED

    features.handles_own_prepacking = True

    return features


@update_features(
[
exir_ops.edge.torchao.choose_qparams_affine.default,
]
)
def register_torchao_quantization_op(features: OpFeatures):
# TorchAO quantization operators - default to per-tensor behavior
# Same features as standard quantization ops
def register_choose_qparams_affine_op(features: OpFeatures):
# Only a rudimentary buffer implementation exists for choose_qparams_affine so far,
# since the reduction logic for blocks in texture3d is not trivial to implement in Vulkan.
features.texture_impl = TextureImplFeatures(
uses_axis_map=True,
uses_axis_map=False,
valid_packed_dims={
PackedDim.WIDTH,
},
Expand All @@ -292,37 +310,6 @@ def register_torchao_quantization_op(features: OpFeatures):
features.resize_fn = True
features.optimal_storage = VkStorageType.BUFFER

def check_torchao_quantization_node(node: torch.fx.Node) -> bool:
# Only per-tensor quantization is supported by the Vulkan backend.
if len(node.args) < 2:
return False

block_size = node.args[1]

if not isinstance(block_size, (list, tuple)):
return False

input_arg = node.args[0]
if not isinstance(input_arg, torch.fx.Node):
return False

input_tensor = input_arg.meta.get("val", None)
if not isinstance(input_tensor, FakeTensor):
return False

input_shape = list(input_tensor.shape)

if len(block_size) != len(input_shape):
return False

# Check if block_size matches input_shape exactly (per-tensor quantization)
for i in range(len(block_size)):
if block_size[i] != input_shape[i]:
return False

return True

features.check_node_fn = check_torchao_quantization_node
return features


Expand Down
100 changes: 54 additions & 46 deletions backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,67 @@
#ifndef CHOOSE_QPARAMS_GLSLH
#define CHOOSE_QPARAMS_GLSLH

// Calculate scale and zero point from min and max values
void calculate_scale_and_zero_point(
float min_val,
float max_val,
int qmin,
int qmax,
float eps_threshold,
out float scale_val,
out int zero_point_val) {
// ensure we have zero included in our range
min_val = min(min_val, 0.0);
max_val = max(max_val, 0.0);
// mapping_type : 0 = ASYM, 1 = SYM, 2 = SYM_NO_CLIP
void calc_scale_zp(
float lo, float hi,
int qmin, int qmax,
int mapping_type,
float eps,
out float scale, out int zp) {
// Handle case where lo and hi are +/-INF (no valid values found)
if (isinf(lo) || isinf(hi)) {
lo = 0.0;
hi = 0.0;
}

scale_val = (max_val - min_val) / float(qmax - qmin);
float minv = min(lo, 0.0);
float maxv = max(hi, 0.0);

// Handle zero or very small scale
if (scale_val == 0.0 || isinf(1.0 / scale_val)) {
scale_val = 0.1;
}
if (mapping_type == 0) { // asymmetric
scale = (maxv - minv) / float(qmax - qmin);

// Handle zero or very small scale
if (scale == 0.0 || isinf(1.0/scale)) {
scale = eps;
}

// Cut off small scale using the provided eps threshold
if (scale_val < eps_threshold) {
float org_scale = scale_val;
scale_val = eps_threshold;
if (scale < eps) {
float org_scale = scale;
scale = eps;

// Adjust min and max based on new scale
if (min_val == 0.0) {
max_val = eps_threshold * float(qmax - qmin);
} else if (max_val == 0.0) {
min_val = -eps_threshold * float(qmax - qmin);
} else {
float amplifier = eps_threshold / org_scale;
min_val *= amplifier;
max_val *= amplifier;
// Adjust min and max based on new scale to maintain proper quantization range
if (minv == 0.0) {
maxv = eps * float(qmax - qmin);
} else if (maxv == 0.0) {
minv = -eps * float(qmax - qmin);
} else {
float amplifier = eps / org_scale;
minv *= amplifier;
maxv *= amplifier;
}
}

// Calculate zero_point (matching reference implementation)
float initial_zero_point = float(qmin) - round(minv / scale);
zp = int(clamp(initial_zero_point, float(qmin), float(qmax)));
} else { // symmetric -- centred
float scale_sym;
if (mapping_type == 1) { // SYM
float M = max(abs(minv), abs(maxv));
scale_sym = M / (float(qmax - qmin) * 0.5);
} else { // SYM_NO_CLIP
float smin = abs(minv) / max(abs(float(qmin)), 1.0); // Avoid division by zero
float smax = maxv / max(float(qmax), 1.0); // Avoid division by zero
scale_sym = max(smin, smax);
}
}

// Calculate zero point
float zero_point_from_min = float(qmin) - min_val / scale_val;
float zero_point_from_max = float(qmax) - max_val / scale_val;
float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val);
float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val);
float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
? zero_point_from_min
: zero_point_from_max;
// Handle zero or very small scale
if (scale_sym == 0.0 || isinf(1.0/scale_sym)) {
scale_sym = eps;
}

// Nudge zero point to integer
if (initial_zero_point < float(qmin)) {
zero_point_val = qmin;
} else if (initial_zero_point > float(qmax)) {
zero_point_val = qmax;
} else {
zero_point_val = int(round(initial_zero_point));
scale = max(scale_sym, eps);
zp = int((qmax + qmin + 1) >> 1); // mid-point – always fits
}
}

Expand Down
Loading
Loading