diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index d584cd128ec..2649ed5b154 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -68,7 +68,7 @@ def _is_float_tensor(node: Node): or not isinstance(node.meta["val"], FakeTensor) ): return False - return node.meta["val"].dtype == torch.float32 + return node.meta["val"].dtype in (torch.bfloat16, torch.float32) def _mark_nodes_as_annotated(nodes: List[Node]):