[Quant][fx] Enable FX static quantization for LSTM #85068
Conversation
Summary: Static quantization for LSTM was previously only supported in eager mode, because it had to go through the custom module flow (torch.ao.nn.quantizable.LSTM). However, this flow returns multiple outputs, which is currently not supported in FX graph mode quantization. In this commit, we work around this limitation by making two important modifications to the flow: (1) The output of the LSTM node is in the form (result, (hidden0, hidden1)), and each of the three internal nodes was already observed, so there is no need to observe the output again. This commit removes the extra output observer that tried to observe the whole tuple, which failed previously. (2) The inputs of the LSTM node must be quantized and the outputs must be dequantized. This was not handled correctly previously because FX graph mode quantization did not have special logic for tuples. For the output in particular, we had to insert ops to manually split the tuple, insert dequantize nodes, and recombine them into the original format. Note that the changes here are intended only as temporary hacks to enable this flow; in the future, we should add better support for dtype inference and tuple handling in FX graph mode quantization.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_static_lstm

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

[ghstack-poisoned]
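For illustration, here is a minimal sketch (not from this PR) of a module whose forward pass produces the nested-tuple output format described above; the wrapper name and layer sizes are arbitrary assumptions:

```python
import torch

class LSTMWrapper(torch.nn.Module):
    """Hypothetical wrapper, only to illustrate the (result, (hidden0, hidden1)) format."""
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=1)

    def forward(self, x, hidden):
        # torch.nn.LSTM returns (output, (h_n, c_n)): a tensor plus a nested
        # tuple, which is the structure the FX flow must split, dequantize,
        # and recombine.
        out, (h, c) = self.lstm(x, hidden)
        return out, (h, c)
```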
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85068
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures, 1 Pending as of commit ad4ddce.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
m = MyModel()
qconfig_mapping = get_default_qconfig_mapping()
prepare_custom_config = PrepareCustomConfig() \
```
nit: maybe we can have a get_default_prepare_custom_config as well (can be in a separate PR)
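For context, a sketch of how such a prepare/convert custom config is typically assembled for the custom module LSTM path, assuming the standard `set_float_to_observed_mapping` / `set_observed_to_quantized_mapping` APIs; the exact configuration used in the PR's test may differ, and `MyModel` above stands in for whatever model contains the LSTM:

```python
import torch
from torch.ao.quantization.fx.custom_config import (
    ConvertCustomConfig,
    PrepareCustomConfig,
)

# prepare_fx swaps the float LSTM for its observed (quantizable) counterpart;
# convert_fx then swaps the observed module for the quantized one.
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)
```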
Sure. In prepare, we make two special cases:
As for convert:
        return matched_node
    return None

def _reroute_tuple_getitem_pattern(graph: Graph):
cc @SherlockNoMad do we want to support something similar in some default fx optimization pass?
looks good, thanks for the hard work!
thanks, please add this to the Summary as well so that we can use this to fix the hacks in the future
Ok, I'm merging this. Thanks for all the feedback!
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Stack from ghstack (oldest at bottom):
Summary: This commit enables the custom module LSTM path for
FX graph mode static quantization. This has the same flow as eager
mode, which was already previously supported:
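```
torch.nn.LSTM
  |  (prepare_fx)
  v
torch.ao.nn.quantizable.LSTM
  |  (convert_fx)
  v
torch.ao.nn.quantized.LSTM
```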
The main reason why custom module LSTM is not supported in FX
graph mode quantization today is because its inputs and outputs
are nested tuples, and existing constructs such as observers,
"quantize" nodes, and "dequantize" nodes do not understand how
to handle complex structures.
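For reference, a plain PyTorch call (with arbitrary, illustrative shapes) shows the nested structure in question:

```python
import torch

# Both the inputs and the outputs of nn.LSTM are nested tuples of tensors,
# not single tensors: (input, (h0, c0)) in and (output, (h_n, c_n)) out.
lstm = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=1)
x = torch.randn(5, 1, 8)                               # (seq_len, batch, input_size)
hidden = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))  # (hidden0, hidden1)
output, (h_n, c_n) = lstm(x, hidden)                   # (output, (hidden0, hidden1))
```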
Note that the approach taken in this commit is only intended to
be a short-term solution highly tailored to the input and output
formats of custom module LSTM. In the future, for the longer-term
solution, we should design a more general QConfig that allows users
to specify complex input and output formats, and enable FX graph
mode quantization to understand arbitrary nested structures and
automatically infer how to transform the graph accordingly.
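For concreteness, a minimal usage sketch of the flow above might look like the following. This is hypothetical (module, shapes, and variable names are made up) and assumes the FX quantization APIs and custom-module configs available around this release; exact names and defaults may differ, and additional backend configuration may be needed.

```python
import torch
import torch.ao.nn.quantizable
import torch.ao.nn.quantized
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import ConvertCustomConfig, PrepareCustomConfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(50, 50, 1)

    def forward(self, inputs, hidden):
        return self.lstm(inputs, hidden)

model = MyModel().eval()
example_inputs = (torch.rand(5, 1, 50), (torch.rand(1, 1, 50), torch.rand(1, 1, 50)))

# Route nn.LSTM through the quantizable (observed) and quantized custom modules
prepare_cfg = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
convert_cfg = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)

prepared = prepare_fx(model, get_default_qconfig_mapping(), example_inputs,
                      prepare_custom_config=prepare_cfg)
prepared(*example_inputs)  # calibrate observers
quantized = convert_fx(prepared, convert_custom_config=convert_cfg)
```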
Context:
Today, in FX graph mode static quantization, custom modules are
assumed to have quantized inputs and quantized outputs, with the
exact dtypes derived from the associated QConfig (default quint8).
Since custom modules are currently not handled through the reference
model flow, their observer replacement logic is a little different
from normal operators:
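```
# (1) Original model
input -> custom_module -> output

# (2) Observed model (after prepare)
input -> obs0 -> custom_module -> obs1 -> output

# (3) Quantized model (after convert)
input -> quant -> quantized_custom_module -> dequant -> output
```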
In the last step, input observers are replaced with "quantize"
and output observers are replaced with "dequantize", in contrast
to other non-custom-module patterns where observers are replaced
with "quantize-dequantize" pairs instead. Note that, conceptually,
the output observer `obs1` is really just a DeQuantStub, since no
observation is actually needed.
Custom module LSTM:
The reason why custom module LSTM cannot be handled in the same
way is because, unlike other custom modules, its inputs and outputs
are nested tuples instead of single tensors. This is how the existing
custom module code would try to handle LSTMs:
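```
# (1) Original model
# input format: (input, (hidden0, hidden1))
# output format: (output, (hidden0, hidden1))
input -> lstm -> output
hidden0 -/   \-> hidden0
hidden1 -/    \-> hidden1

# (2) Observed model (after prepare)
input -> obs0 -> lstm -> obs1   # fails
hidden0 -/   # missing observer
hidden1 -/   # missing observer
```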
However, this fails today because 1) we assume there is only one input
to the custom module, and so we never end up quantizing `hidden0` and
`hidden1`, and 2) the output observer `obs1` is fed a tuple, which it
does not understand how to handle.
Short-term fix:
This commit addresses the above by specifically handling the input
and output structures used by custom module LSTM. For the inputs,
we manually insert observers for `hidden0` and `hidden1` to ensure
all input tensors are quantized.
For the outputs, we split the tuple into its internal nodes, attach
a DeQuantStub to each node, and recombine these DeQuantStubs
according to the original structure. Finally, we must also reroute
consumers of the original LSTM tuple (and its internal nodes, e.g.
`lstm[0]`) to these DeQuantStubs:
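```
# (1) Original model
input -> lstm -> output -> linear0
hidden0 -/   \-> hidden0 -> linear1
hidden1 -/    \-> hidden1 -> linear2

# (2) Observed model (after prepare)
input -> obs0 -> lstm -> output -> dqstub -> linear0 -> obs3
hidden0 -> obs1 -/   \-> hidden0 -> dqstub -> linear1 -> obs4
hidden1 -> obs2 -/    \-> hidden1 -> dqstub -> linear2 -> obs5

# (3) Reference model (after convert)
input -> quant -> qlstm -> output -> dequant -> linear0 -> quant -> dequant
hidden0 -> quant -/   \-> hidden0 -> dequant -> linear1 -> quant -> dequant
hidden1 -> quant -/    \-> hidden1 -> dequant -> linear2 -> quant -> dequant

# (4) Quantized model (after lowering)
input -> quant -> qlstm -> output -> quantized_linear0 -> dequant
hidden0 -> quant -/   \-> hidden0 -> quantized_linear1 -> dequant
hidden1 -> quant -/    \-> hidden1 -> quantized_linear2 -> dequant
```
Note that we choose to insert DeQuantStubs here instead of observers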
because these will ultimately be replaced by "dequantize" nodes. This
matches the general custom module behavior, where output observers
are replaced only with "dequantize" nodes (as opposed to the normal
"quantize-dequantize" pair), since custom module outputs are assumed
to already be quantized. Using DeQuantStubs instead of observers also
simplifies the "dequantize" insertion logic. In the future, we should use
DeQuantStubs in place of output observers for custom modules in general.
Implementation:
In prepare, we handle two special cases:
(1) When inserting output observers, we instead break the LSTM output
tuple into its internal nodes, insert a DeQuantStub after each node,
and recombine the tuple according to the original format
(in `_insert_dequant_stubs_for_custom_module_lstm_output`). This
diverges from, but is conceptually the same as, how we insert "special"
output observers that will be converted to "dequantize" nodes (during
convert) today for other custom modules. In the future, we should just
insert DeQuantStubs instead of these "special" observers for custom
modules in general.
(2) When inserting input observers for a node, we check whether it is
a consumer of LSTM by traversing up the subgraph created in (1)
(see `_maybe_get_custom_module_lstm_from_node_arg`).
Once we have identified what the previous node is, we reuse the existing code
in `maybe_insert_input_observer_for_arg_or_kwarg` to decide whether or
not to insert the input observer based on the output dtype of the
previous node.
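As a rough illustration of the kind of graph surgery described in point (1) above, a minimal sketch with torch.fx might look like the following. This is not the PR's actual helper; the function and submodule names are made up, and rerouting of downstream consumers is omitted.

```python
import operator

import torch
from torch.ao.quantization import DeQuantStub
from torch.fx import GraphModule, Node

def insert_lstm_dequant_stubs(gm: GraphModule, lstm_node: Node) -> Node:
    """Split an (output, (h0, h1)) LSTM output, stub each piece, and recombine."""
    graph = gm.graph
    # Break the nested tuple output into its three internal tensor nodes
    with graph.inserting_after(lstm_node):
        out = graph.call_function(operator.getitem, (lstm_node, 0))
        hidden = graph.call_function(operator.getitem, (lstm_node, 1))
        h0 = graph.call_function(operator.getitem, (hidden, 0))
        h1 = graph.call_function(operator.getitem, (hidden, 1))
    # Attach a DeQuantStub to each internal node
    stubbed = []
    for i, node in enumerate((out, h0, h1)):
        stub_name = f"dequant_stub_{i}"
        gm.add_submodule(stub_name, DeQuantStub())
        with graph.inserting_after(node):
            stubbed.append(graph.call_module(stub_name, (node,)))
    # Recombine the stubbed nodes into the original (output, (h0, h1)) format
    with graph.inserting_after(stubbed[-1]):
        new_hidden = graph.call_function(tuple, ([stubbed[1], stubbed[2]],))
        new_output = graph.call_function(tuple, ([stubbed[0], new_hidden],))
    gm.recompile()
    return new_output  # consumers of the original tuple would be rerouted here
```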
As for convert:
(1) When converting the custom module, we change
`quantize - dequantize - custom_module` to `quantize - custom_module`
for all inputs of custom module LSTM, including the hidden inputs (in
`convert_custom_module`). This is consistent with how we do the same
for custom module inputs today, except it handles more than one input.
(2) We replace DeQuantStubs with "dequantize" nodes. This is not
really a special case, as we plan to do this for all custom modules
in the future.
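A minimal sketch of what point (2) amounts to (not the PR's code; the helper name is made up) is replacing each DeQuantStub call_module node with a "dequantize" call_method node on its input:

```python
import torch
from torch.ao.quantization import DeQuantStub
from torch.fx import GraphModule

def replace_dequant_stubs(gm: GraphModule) -> None:
    """Swap DeQuantStub call_module nodes for tensor.dequantize() call_method nodes."""
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and isinstance(modules[node.target], DeQuantStub):
            with gm.graph.inserting_after(node):
                # node.args is (input_node,), so this dequantizes the stub's input
                dequant = gm.graph.call_method("dequantize", node.args)
            node.replace_all_uses_with(dequant)
            gm.graph.erase_node(node)
    gm.recompile()
```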
Test plan:
python test/test_quantization.py TestQuantizeFx.test_static_lstm
python test/test_quantization.py TestQuantizeFx.test_static_lstm_consume_tuple
Reviewers: jerryzh168, vkuzo
Subscribers: jerryzh168, vkuzo