Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backends/xnnpack/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ set(XNNPACK_ENABLE_AVX512VNNIGFNI
OFF
CACHE BOOL ""
)

# Default the XNNPACK_ENABLE_ARM_SME2 cache option to ON. No FORCE is given,
# so a value already present in the CMake cache is left unchanged.
set(XNNPACK_ENABLE_ARM_SME2
ON
CACHE BOOL ""
)
if(EXECUTORCH_XNNPACK_ENABLE_KLEIDI)
set(XNNPACK_ENABLE_KLEIDIAI
ON
Expand Down
16 changes: 10 additions & 6 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
auto cvt_output_id = graph_node->output_id();

auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
assert(
dtype == DataType::xnn_datatype_qdint8 ||
dtype == DataType::xnn_datatype_qbint4);
for (auto value : *graph->xvalues()) {
if (value->xvalue_union_type() !=
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
Expand All @@ -631,16 +628,23 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
return false;
}

// XNNPACK filter dtypes that have QP8 support.
const std::vector<DataType> supported_filter_dtypes = {
DataType::xnn_datatype_qbint4,
DataType::xnn_datatype_qcint4,
DataType::xnn_datatype_qcint8};

// Check whether the convert output feeds a suitable linear node.
// We assume that if one valid linear node is found, QP8 can be used
// for all linear nodes consuming this convert output.
for (auto node : *graph->xnodes()) {
if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
auto linear_node = node->xnode_union_as_XNNFullyConnected();
if (linear_node->input1_id() == cvt_output_id) {
if (check_dtype(
linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
return true;
for (auto supported_filter_dtype : supported_filter_dtypes) {
if (check_dtype(linear_node->filter_id(), supported_filter_dtype)) {
return true;
}
}
}
}
Expand Down
Loading