diff --git a/backends/xnnpack/cmake/Dependencies.cmake b/backends/xnnpack/cmake/Dependencies.cmake index ce25f5cec22..62c8e7e0288 100644 --- a/backends/xnnpack/cmake/Dependencies.cmake +++ b/backends/xnnpack/cmake/Dependencies.cmake @@ -42,7 +42,10 @@ set(XNNPACK_ENABLE_AVX512VNNIGFNI OFF CACHE BOOL "" ) - +set(XNNPACK_ENABLE_ARM_SME2 + ON + CACHE BOOL "" +) if(EXECUTORCH_XNNPACK_ENABLE_KLEIDI) set(XNNPACK_ENABLE_KLEIDIAI ON diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 606118bbd05..ad927ef8917 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { auto cvt_output_id = graph_node->output_id(); auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool { - assert( - dtype == DataType::xnn_datatype_qdint8 || - dtype == DataType::xnn_datatype_qbint4); for (auto value : *graph->xvalues()) { if (value->xvalue_union_type() != fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) { @@ -631,6 +628,12 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { return false; } + // XNNPACK dtypes which have qp8 support. + const std::vector supported_filter_dtypes = { + DataType::xnn_datatype_qbint4, + DataType::xnn_datatype_qcint4, + DataType::xnn_datatype_qcint8}; + // Find if the convert output is going to the right linear node. // Assuming if we can find one valid linear node, then we can use QP8 // for all the linear nodes consuming this convert output. @@ -638,9 +641,10 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) { auto linear_node = node->xnode_union_as_XNNFullyConnected(); if (linear_node->input1_id() == cvt_output_id) { - if (check_dtype( - linear_node->filter_id(), DataType::xnn_datatype_qbint4)) { - return true; + for (auto supported_filter_dtype : supported_filter_dtypes) { + if (check_dtype(linear_node->filter_id(), supported_filter_dtype)) { + return true; + } } } }