[ao][fx] Enable observed -> quantized float for static quantized MultiheadAttention (#95636)

Test Plan:
Sandcastle

cc andrewor14 any suggestions here?

Differential Revision: D43631794

Pull Request resolved: #95636
Approved by: https://github.com/andrewor14
kev-zheng authored and pytorchmergebot committed Feb 28, 2023
1 parent fafb410 commit f1dbfe2
Showing 2 changed files with 35 additions and 0 deletions.
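
For context, a minimal sketch of the FX graph mode quantization flow this commit feeds into, assuming the MultiheadAttention custom-module path mirrors the existing LSTM one. The wrapper model, example inputs, and backend choice are illustrative only, and full end-to-end conversion may still depend on the output handling marked TODO in the diff below.

import torch
import torch.nn as nn
import torch.ao.nn.quantizable
import torch.ao.nn.quantized
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import ConvertCustomConfig, PrepareCustomConfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class AttnModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=16, num_heads=2)

    def forward(self, query, key, value):
        attn_out, _ = self.mha(query, key, value)
        return attn_out

model = AttnModel().eval()
example_inputs = (torch.randn(4, 1, 16),) * 3  # (query, key, value), shape (L, N, E)

# float -> observed: swap nn.MultiheadAttention for its quantizable (observed) variant
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    nn.MultiheadAttention, torch.ao.nn.quantizable.MultiheadAttention)
# observed -> quantized: the mapping the convert-side handling in this commit targets
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.MultiheadAttention, torch.ao.nn.quantized.MultiheadAttention)

prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs,
                      prepare_custom_config=prepare_custom_config)
prepared(*example_inputs)  # calibrate with representative data
quantized = convert_fx(prepared, convert_custom_config=convert_custom_config)
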
torch/ao/quantization/fx/convert.py (16 additions, 0 deletions)

@@ -50,6 +50,7 @@
from .utils import (
    _get_module,
    _is_custom_module_lstm,
    _is_custom_module_mha,
    get_custom_module_class_keys,
    create_getattr_from_value,
    collect_producer_nodes,

@@ -814,6 +815,21 @@ def convert_custom_module(
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # Inputs are in the form (query, key, value)
            # TODO: This is the first step in enabling the full fx custom module
            # quantization path for MultiheadAttention, and only covers the inputs
            # to the module.
            # Additional handling is yet to be implemented for the outputs, similar
            # to LSTM custom module
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
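
For readers of the hunk above: `_remove_previous_dequantize_in_custom_module` is an existing helper used here. A rough, paraphrased sketch of what it does (not the verbatim source) is:

from torch.fx import Graph, Node

def _remove_previous_dequantize_sketch(node: Node, prev_node: Node, graph: Graph) -> None:
    # If the custom module input `prev_node` is a dequantize call, rewire the
    # custom module to consume the quantized tensor feeding that dequantize:
    #   before: quantize -> dequantize -> observed custom module
    #   after:  quantize -> observed custom module
    # (`graph` is kept only to mirror the real helper's call signature above.)
    if prev_node.op == "call_method" and prev_node.target == "dequantize":
        node.replace_input_with(prev_node, prev_node.args[0])

This is how the new elif branch ensures the observed MultiheadAttention receives quantized query/key/value tensors rather than dequantized floats.
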
torch/ao/quantization/fx/utils.py (19 additions, 0 deletions)

@@ -464,6 +464,25 @@ def _is_custom_module_lstm(
    else:
        return isinstance(mod, torch.ao.nn.quantizable.LSTM)

def _is_custom_module_mha(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    qconfig: QConfigAny = None,
    # QuantizeHandler, but we cannot include the type here due to circular imports
    qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module MultiheadAttention flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is not None and qhandler is not None:
        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
        return isinstance(mod, torch.nn.MultiheadAttention) and \
            activation_is_statically_quantized(qconfig) and \
            qhandler.is_custom_module()
    else:
        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)

def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
    """
    If `node` refers to a call_module node, return the module, else None.
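
The two branches of `_is_custom_module_mha` correspond to when the check runs: at prepare time the module is still a float `nn.MultiheadAttention` (checked alongside its static qconfig and custom-module QuantizeHandler), while at convert time prepare has already swapped it for the observed quantizable variant. A tiny illustrative check, with hypothetical module parameters and not taken from a real call site:

import torch
import torch.nn as nn
import torch.ao.nn.quantizable

float_mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
observed_mha = torch.ao.nn.quantizable.MultiheadAttention(embed_dim=8, num_heads=2)

# Prepare-time branch (qconfig and qhandler provided): the module is the float class.
assert isinstance(float_mha, nn.MultiheadAttention)
# Convert-time branch (no qconfig/qhandler): the module is the observed quantizable class.
assert isinstance(observed_mha, torch.ao.nn.quantizable.MultiheadAttention)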
