@@ -19,6 +19,8 @@
#define MAX_THREADS 256

${define_active_storage_type(STORAGE)}

+${define_required_extensions(DTYPE)}
+${define_required_extensions("int8")}

#extension GL_EXT_control_flow_attributes : require
@@ -126,8 +128,8 @@ void find_min_max_for_row(const int output_y) {
const int X4 = div_4(input_sizes.x);

// Initialize thread-local min/max
-float local_min = 1e30;
-float local_max = -1e30;
+T local_min = T(1e30);
+T local_max = T(-1e30);

// Each thread processes elements along their assigned output_id with stride
// NUM_WORKERS_PER_OUTPUT
@@ -187,7 +189,7 @@ void main() {
calculate_scale_and_zero_point(
local_min, local_max, quant_min, quant_max, scale, zero_point);

-scales_out[i] = scale;
+scales_out[i] = T(scale);
zps_out[i] = zero_point;
}
}
@@ -14,5 +14,6 @@ choose_qparams_per_row:
- VALUE: buffer
DTYPE:
- VALUE: float
+- VALUE: half
shader_variants:
- NAME: choose_qparams_per_row
@@ -16,6 +16,7 @@ linear_dq8ca_q4gsw_tiled:
generate_variant_forall:
DTYPE:
- VALUE: float
+- VALUE: half
shader_variants:
- NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d
- NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer
@@ -77,7 +77,7 @@ void accumulate_out_tile_with_int_accum_from_int4_weights(

out_tile.data[m][n4] =
fma(VEC4_T(accum_adjusted),
-input_scale_m * weight_scales.data[n4],
+VEC4_T(input_scale_m * weight_scales.data[n4]),
out_tile.data[m][n4]);
}
}
@@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum(
input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
out_tile.data[m][n4] =
fma(VEC4_T(accum_adjusted),
-input_q_scale * weight_scales.data[0],
+VEC4_T(input_q_scale * weight_scales.data[0]),
out_tile.data[m][n4]);
}
}
@@ -98,7 +98,7 @@ void accumulate_out_tile_with_int_accum(
input_zp_vec * weight_sums.data[n4] + accum.data[m][n4];
out_tile.data[m][n4] =
fma(VEC4_T(accum_adjusted),
-input_q_scale * weight_scales.data[n4],
+VEC4_T(input_q_scale * weight_scales.data[n4]),
out_tile.data[m][n4]);
out_tile.data[m][n4] += bias.data[n4];
}
@@ -16,6 +16,8 @@

${define_active_storage_type(STORAGE)}

+${define_required_extensions(DTYPE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;
@@ -85,7 +87,7 @@ void main() {
}

// Initialize thread-local min/max
-T local_exp_sum = 0;
+T local_exp_sum = T(0);

const int context_len_aligned_down = context_len - mod_4(context_len);
const int C4_limit = div_4(context_len_aligned_down);
@@ -14,5 +14,6 @@ sdpa_attn_weights_softmax:
- VALUE: buffer
DTYPE:
- VALUE: float
+- VALUE: half
shader_variants:
- NAME: sdpa_attn_weights_softmax
4 changes: 4 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -418,6 +418,7 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
)
parser.add_argument("-V", "--vulkan", action="store_true")
parser.add_argument("--vulkan-force-fp16", action="store_true")
parser.add_argument("--mps", action="store_true")
parser.add_argument("--coreml", action="store_true")
parser.add_argument(
@@ -885,6 +886,7 @@ def _to_edge_and_lower_llama( # noqa: C901
use_kv_cache: bool = False,
embedding_quantize: Optional[str] = None,
pt2e_quantize: Optional[str] = None,
+vulkan_force_fp16: bool = False,
coreml_ios: int = 15,
coreml_quantize: Optional[str] = None,
coreml_compute_units: str = "cpu_only",
@@ -905,6 +907,7 @@ def _to_edge_and_lower_llama( # noqa: C901
get_vulkan_partitioner(
dtype_override,
enable_dynamic_shape,
+vulkan_force_fp16,
)
)
modelname = f"vulkan_{modelname}"
@@ -1125,6 +1128,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
if llm_config.quantization.pt2e_quantize
else None
),
+vulkan_force_fp16=llm_config.backend.vulkan.force_fp16,
coreml_ios=llm_config.backend.coreml.ios,
coreml_quantize=(
llm_config.backend.coreml.quantize.value
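Note: both Vulkan arguments are plain store_true flags, so argparse exposes them as booleans and maps dashes to underscores on the parsed namespace. A minimal, standalone sketch of how the new flag parses alongside -V/--vulkan (hypothetical snippet, not the real build_args_parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-V", "--vulkan", action="store_true")
    parser.add_argument("--vulkan-force-fp16", action="store_true")

    # Dashes in option names become underscores on the parsed namespace.
    args = parser.parse_args(["--vulkan", "--vulkan-force-fp16"])
    assert args.vulkan is True and args.vulkan_force_fp16 is True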
3 changes: 3 additions & 0 deletions extension/llm/export/config/llm_config.py
@@ -426,6 +426,7 @@ class VulkanConfig:
"""

enabled: bool = False
+force_fp16: bool = False


@dataclass
@@ -610,6 +611,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
# Vulkan
if hasattr(args, "vulkan"):
llm_config.backend.vulkan.enabled = args.vulkan
if hasattr(args, "vulkan_force_fp16"):
llm_config.backend.vulkan.force_fp16 = args.vulkan_force_fp16

# QNN
if hasattr(args, "qnn"):
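For context, a simplified sketch of how the new field travels from the parsed arguments into the config object (assumes dataclasses shaped like the VulkanConfig above; the real classes and from_args logic live in llm_config.py):

    import argparse
    from dataclasses import dataclass, field

    @dataclass
    class VulkanConfig:
        enabled: bool = False
        force_fp16: bool = False

    @dataclass
    class BackendConfig:
        vulkan: VulkanConfig = field(default_factory=VulkanConfig)

    # Mirrors the hasattr() guards in LlmConfig.from_args: each flag is only
    # copied over when it is present on the parsed namespace.
    def apply_vulkan_args(backend: BackendConfig, args: argparse.Namespace) -> None:
        if hasattr(args, "vulkan"):
            backend.vulkan.enabled = args.vulkan
        if hasattr(args, "vulkan_force_fp16"):
            backend.vulkan.force_fp16 = args.vulkan_force_fp16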
8 changes: 6 additions & 2 deletions extension/llm/export/partitioner_lib.py
@@ -32,7 +32,9 @@ def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):


def get_vulkan_partitioner(
-dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False
+dtype_override: Optional[str] = None,
+enable_dynamic_shape: bool = False,
+force_fp16: bool = False,
):
assert (
dtype_override == "fp32" or dtype_override is None
@@ -41,7 +43,9 @@
VulkanPartitioner,
)

-return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape})
+return VulkanPartitioner(
+    {"require_dynamic_shapes": enable_dynamic_shape, "force_fp16": force_fp16}
+)


def get_mps_partitioner(use_kv_cache: bool = False):
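The change amounts to one extra key in the options dict handed to VulkanPartitioner. A self-contained stand-in for that dict construction (sketch only; the real function also enforces the dtype_override assertion and imports VulkanPartitioner lazily):

    # Stand-in with an assumed shape: the options dict that get_vulkan_partitioner
    # now builds and passes to VulkanPartitioner.
    def build_vulkan_partitioner_options(
        enable_dynamic_shape: bool = False, force_fp16: bool = False
    ) -> dict:
        return {
            "require_dynamic_shapes": enable_dynamic_shape,
            "force_fp16": force_fp16,
        }

    assert build_vulkan_partitioner_options(True, True) == {
        "require_dynamic_shapes": True,
        "force_fp16": True,
    }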