diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
index 639fe312148..7234b50a3f5 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
@@ -19,6 +19,8 @@
 #define MAX_THREADS 256
 
 ${define_active_storage_type(STORAGE)}
+
+${define_required_extensions(DTYPE)}
 ${define_required_extensions("int8")}
 
 #extension GL_EXT_control_flow_attributes : require
@@ -126,8 +128,8 @@ void find_min_max_for_row(const int output_y) {
   const int X4 = div_4(input_sizes.x);
 
   // Initialize thread-local min/max
-  float local_min = 1e30;
-  float local_max = -1e30;
+  T local_min = T(1e30);
+  T local_max = T(-1e30);
 
   // Each thread processes elements along their assigned output_id with stride
   // NUM_WORKERS_PER_OUTPUT
@@ -187,7 +189,7 @@ void main() {
     calculate_scale_and_zero_point(
         local_min, local_max, quant_min, quant_max, scale, zero_point);
 
-    scales_out[i] = scale;
+    scales_out[i] = T(scale);
     zps_out[i] = zero_point;
   }
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
index 1594bb574bd..5dbf3d7adaa 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
@@ -14,5 +14,6 @@ choose_qparams_per_row:
       - VALUE: buffer
     DTYPE:
       - VALUE: float
+      - VALUE: half
   shader_variants:
     - NAME: choose_qparams_per_row
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml
index cb9cdc4a046..a252055ed40 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml
@@ -16,6 +16,7 @@ linear_dq8ca_q4gsw_tiled:
   generate_variant_forall:
     DTYPE:
       - VALUE: float
+      - VALUE: half
   shader_variants:
     - NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d
     - NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh
index 7bc7071ab1f..0a11ed6f482 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh
@@ -77,7 +77,7 @@ void accumulate_out_tile_with_int_accum_from_int4_weights(
 
       out_tile.data[m][n4] = fma(
           VEC4_T(accum_adjusted),
-          input_scale_m * weight_scales.data[n4],
+          VEC4_T(input_scale_m * weight_scales.data[n4]),
           out_tile.data[m][n4]);
     }
   }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh
index 68ac269e9d7..ca25e406ac1 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh
@@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum(
           input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
       out_tile.data[m][n4] = fma(
           VEC4_T(accum_adjusted),
-          input_q_scale * weight_scales.data[0],
+          VEC4_T(input_q_scale * weight_scales.data[0]),
           out_tile.data[m][n4]);
     }
   }
@@ -98,7 +98,7 @@ void accumulate_out_tile_with_int_accum(
           input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
       out_tile.data[m][n4] = fma(
           VEC4_T(accum_adjusted),
-          input_q_scale * weight_scales.data[n4],
+          VEC4_T(input_q_scale * weight_scales.data[n4]),
           out_tile.data[m][n4]);
       out_tile.data[m][n4] += bias.data[n4];
     }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl
index 1dff0017f30..67d9c100f68 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl
@@ -16,6 +16,8 @@
 
 ${define_active_storage_type(STORAGE)}
 
+${define_required_extensions(DTYPE)}
+
 #extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
@@ -85,7 +87,7 @@ void main() {
   }
 
   // Initialize thread-local min/max
-  T local_exp_sum = 0;
+  T local_exp_sum = T(0);
 
   const int context_len_aligned_down = context_len - mod_4(context_len);
   const int C4_limit = div_4(context_len_aligned_down);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml
index 8abf50399e0..66ec030680e 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml
@@ -14,5 +14,6 @@ sdpa_attn_weights_softmax:
       - VALUE: buffer
     DTYPE:
       - VALUE: float
+      - VALUE: half
   shader_variants:
     - NAME: sdpa_attn_weights_softmax
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index a72dd1d9b3b..7192204a141 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -418,6 +418,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
     )
     parser.add_argument("-V", "--vulkan", action="store_true")
+    parser.add_argument("--vulkan-force-fp16", action="store_true")
     parser.add_argument("--mps", action="store_true")
     parser.add_argument("--coreml", action="store_true")
     parser.add_argument(
@@ -885,6 +886,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     use_kv_cache: bool = False,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
+    vulkan_force_fp16: bool = False,
     coreml_ios: int = 15,
     coreml_quantize: Optional[str] = None,
     coreml_compute_units: str = "cpu_only",
@@ -905,6 +907,7 @@
             get_vulkan_partitioner(
                 dtype_override,
                 enable_dynamic_shape,
+                vulkan_force_fp16,
             )
         )
         modelname = f"vulkan_{modelname}"
@@ -1125,6 +1128,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             if llm_config.quantization.pt2e_quantize
             else None
         ),
+        vulkan_force_fp16=llm_config.backend.vulkan.force_fp16,
        coreml_ios=llm_config.backend.coreml.ios,
         coreml_quantize=(
             llm_config.backend.coreml.quantize.value
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index 04991c0e73e..d756d1886ad 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -426,6 +426,7 @@ class VulkanConfig:
     """
 
     enabled: bool = False
+    force_fp16: bool = False
 
 
 @dataclass
@@ -610,6 +611,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         # Vulkan
         if hasattr(args, "vulkan"):
             llm_config.backend.vulkan.enabled = args.vulkan
+        if hasattr(args, "vulkan_force_fp16"):
+            llm_config.backend.vulkan.force_fp16 = args.vulkan_force_fp16
 
         # QNN
         if hasattr(args, "qnn"):
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 7b093a7f1a3..5fe220f7dd9 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -32,7 +32,9 @@ def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):
 
 
 def get_vulkan_partitioner(
-    dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False
+    dtype_override: Optional[str] = None,
+    enable_dynamic_shape: bool = False,
+    force_fp16: bool = False,
 ):
     assert (
         dtype_override == "fp32" or dtype_override is None
@@ -41,7 +43,9 @@ def get_vulkan_partitioner(
         VulkanPartitioner,
     )
 
-    return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape})
+    return VulkanPartitioner(
+        {"require_dynamic_shapes": enable_dynamic_shape, "force_fp16": force_fp16}
+    )
 
 
 def get_mps_partitioner(use_kv_cache: bool = False):