
[Quant][fx] Enable FX static quantization for LSTM #85068

Closed
wants to merge 14 commits

Conversation

Contributor

@andrewor14 andrewor14 commented Sep 15, 2022

Stack from ghstack (oldest at bottom):

Summary: This commit enables the custom module LSTM path for
FX graph mode static quantization. This has the same flow as eager
mode, which was already supported:

     torch.nn.LSTM
           | (prepare_fx)
           v
torch.ao.nn.quantizable.LSTM
           | (convert_fx)
           v
 torch.ao.nn.quantized.LSTM
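
For reference, here is a minimal usage sketch of this flow. The explicit
custom module mappings, model, and example-input details below are
illustrative assumptions, not copied verbatim from the tests in this PR:

```python
import torch
import torch.ao.nn.quantizable
import torch.ao.nn.quantized
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import (
    ConvertCustomConfig,
    PrepareCustomConfig,
)
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(50, 50, 1)

    def forward(self, inputs, hidden):
        return self.lstm(inputs, hidden)

m = MyModel()
example_inputs = (torch.rand(5, 3, 50),
                  (torch.rand(1, 3, 50), torch.rand(1, 3, 50)))
qconfig_mapping = get_default_qconfig_mapping()

# Swap nn.LSTM -> quantizable LSTM during prepare, then quantizable -> quantized
# during convert, mirroring the eager mode custom module flow
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)

prepared = prepare_fx(m, qconfig_mapping, example_inputs, prepare_custom_config)
prepared(*example_inputs)  # calibrate
quantized = convert_fx(prepared, convert_custom_config=convert_custom_config)
```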

The main reason why custom module LSTM is not supported in FX
graph mode quantization today is because its inputs and outputs
are nested tuples, and existing constructs such as observers,
"quantize" nodes, and "dequantize" nodes do not understand how
to handle complex structures.

Note that the approach taken in this commit is only intended to
be a short-term solution highly tailored to the input and output
formats of custom module LSTM. In the future, for the longer-term
solution, we should design a more general QConfig that allows users
to specify complex input and output formats, and enable FX graph
mode quantization to understand arbitrary nested structures and
automatically infer how to transform the graph accordingly.

Context:

Today, in FX graph mode static quantization, custom modules are
assumed to have quantized inputs and quantized outputs, with the
exact dtypes derived from the associated QConfig (default quint8).
Since custom modules are currently not handled through the reference
model flow, their observer replacement logic is a little different
from that of normal operators:

# (1) Original model
input -> custom_module -> output

# (2) Observed model (after prepare)
input -> obs0 -> custom_module -> obs1 -> output

# (3) Quantized model (after convert)
input -> quant -> quantized_custom_module -> dequant -> output

In the last step, input observers are replaced with "quantize"
and output observers are replaced with "dequantize", in contrast
to other non-custom-module patterns where observers are replaced
with "quantize-dequantize" pairs instead. Note that, conceptually,
the output observer obs1 is really just a DeQuantStub, since no
observation is actually needed.

Custom module LSTM:

The reason why custom module LSTM cannot be handled in the same
way is because, unlike other custom modules, its inputs and outputs
are nested tuples instead of single tensors. This is how the existing
custom module code would try to handle LSTMs:

# (1) Original model
# input format: (input, (hidden0, hidden1))
# output format:  (output, (hidden0, hidden1))
 input -> lstm -> output
hidden0 -/    \-> hidden0
hidden1 -/    \-> hidden1

# (2) Observed model (after prepare)
 input -> obs0 -> lstm -> obs1  # fails
        hidden0 -/  # missing observer
        hidden1 -/  # missing observer

However, this fails today because 1) we assume there is only one input
to the custom module, and so we never end up quantizing hidden0 and
hidden1, and 2) the output observer obs1 is fed a tuple, which it
does not understand how to handle.
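
For concreteness, the nested tuple formats in question are just the standard
torch.nn.LSTM calling convention:

```python
import torch

lstm = torch.nn.LSTM(input_size=50, hidden_size=50, num_layers=1)
x = torch.rand(5, 3, 50)   # (seq_len, batch, input_size)
h0 = torch.rand(1, 3, 50)  # (num_layers, batch, hidden_size)
c0 = torch.rand(1, 3, 50)

# Inputs are (input, (hidden0, hidden1)) and outputs are (output, (hidden0, hidden1)),
# so there is no single input or output tensor to attach an observer to.
output, (hn, cn) = lstm(x, (h0, c0))
```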

Short-term fix:

This commit addresses the above by specifically handling the input
and output structures used by custom module LSTM. For the inputs,
we manually insert observers for hidden0 and hidden1 to ensure
all input tensors are quantized.

For the outputs, we split the tuple into its internal nodes, attach
a DeQuantStub to each node, and recombine these DeQuantStubs
according to the original structure. Finally, we must also reroute
consumers of the original LSTM tuple (and its internal nodes, e.g.
lstm[0]) to these DeQuantStubs:

# (1) Original model
 input -> lstm -> output -> linear0
hidden0 -/    \-> hidden0 -> linear1
hidden1 -/    \-> hidden1 -> linear2

# (2) Observed model (after prepare)
 input -> obs0 -> lstm -> output -> dqstub -> linear0 -> obs3
hidden0 -> obs1 -/    \-> hidden0 -> dqstub -> linear1 -> obs4
hidden1 -> obs2 -/    \-> hidden1 -> dqstub -> linear2 -> obs5

# (3) Reference model (after convert)
 input -> quant -> qlstm -> output -> dequant -> linear0 -> quant -> dequant
hidden0 -> quant -/    \-> hidden0 -> dequant -> linear1 -> quant -> dequant
hidden1 -> quant -/    \-> hidden1 -> dequant -> linear2 -> quant -> dequant

# (4) Quantized model (after lowering)
 input -> quant -> qlstm -> output -> quantized_linear0 -> dequant
hidden0 -> quant -/    \-> hidden0 -> quantized_linear1 -> dequant
hidden1 -> quant -/    \-> hidden1 -> quantized_linear2 -> dequant

Note that we choose to insert DeQuantStubs here instead of observers
because these will ultimately be replaced by "dequantize" nodes. This
matches the general custom module behavior, where output observers
are replaced only with "dequantize" nodes (as opposed to the normal
"quantize-dequantize" pair), since custom module outputs are assumed
to already be quantized. Using DeQuantStubs instead of observers also
simplifies the "dequantize" insertion logic. In the future, we should use
DeQuantStubs in place of output observers for custom modules in general.
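
The following is a simplified, hypothetical sketch of the split / DeQuantStub /
recombine step described above. It is not the actual implementation (that lives
in _insert_dequant_stubs_for_custom_module_lstm_output, see below); the helper
name and graph-manipulation details here are illustrative only:

```python
import operator
import torch
import torch.fx
from torch.ao.quantization import DeQuantStub

def insert_dequant_stubs_for_lstm_output(model: torch.fx.GraphModule,
                                         lstm_node: torch.fx.Node) -> torch.fx.Node:
    graph = model.graph
    # Split (output, (hidden0, hidden1)) into its three internal nodes
    with graph.inserting_after(lstm_node):
        out = graph.call_function(operator.getitem, (lstm_node, 0))
    with graph.inserting_after(out):
        hidden = graph.call_function(operator.getitem, (lstm_node, 1))
    with graph.inserting_after(hidden):
        h0 = graph.call_function(operator.getitem, (hidden, 0))
    with graph.inserting_after(h0):
        h1 = graph.call_function(operator.getitem, (hidden, 1))
    # Attach a DeQuantStub to each internal node
    dq_nodes, last = [], h1
    for i, node in enumerate((out, h0, h1)):
        name = f"lstm_dqstub_{i}"
        model.add_module(name, DeQuantStub())
        with graph.inserting_after(last):
            last = graph.call_module(name, (node,))
        dq_nodes.append(last)
    # Recombine the DeQuantStubs into the original (output, (hidden0, hidden1)) structure
    with graph.inserting_after(last):
        hidden_tuple = graph.call_function(tuple, ([dq_nodes[1], dq_nodes[2]],))
    with graph.inserting_after(hidden_tuple):
        new_out = graph.call_function(tuple, ([dq_nodes[0], hidden_tuple],))
    # Reroute existing consumers of the LSTM tuple to the recombined tuple,
    # skipping the getitem nodes created above
    for user in list(lstm_node.users):
        if user not in (out, hidden):
            user.replace_input_with(lstm_node, new_out)
    graph.lint()
    model.recompile()
    return new_out
```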

Implementation:

In prepare, we handle two special cases:

(1) When inserting output observers, we instead break the LSTM output
tuple into its internal nodes, insert a DeQuantStub after each node,
and recombine the tuple according to the original format
(in _insert_dequant_stubs_for_custom_module_lstm_output). This
diverges from, but is conceptually the same as, how we currently insert
"special" output observers (which are converted to "dequantize" nodes
during convert) for other custom modules. In the future, we should just
insert DeQuantStubs instead of these "special" observers for custom
modules in general.

(2) When inserting input observers for a node, we check whether the node
is a consumer of the LSTM by traversing up the subgraph created in (1)
(see _maybe_get_custom_module_lstm_from_node_arg; a rough sketch of this
traversal follows below). Once we have identified the previous node, we
reuse the existing code in maybe_insert_input_observer_for_arg_or_kwarg
to decide whether or not to insert the input observer, based on the
output dtype of the previous node.
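
As a rough illustration of (2), the upward traversal could look something like
the sketch below; this is a hypothetical approximation, not the actual
_maybe_get_custom_module_lstm_from_node_arg:

```python
import operator
import torch
import torch.fx
import torch.ao.nn.quantizable
from torch.ao.quantization import DeQuantStub

def maybe_get_lstm_node(arg, model: torch.fx.GraphModule):
    """Walk back up the dqstub / getitem / tuple subgraph created in (1) and
    return the underlying LSTM node, or None if `arg` does not come from one."""
    node = arg
    while isinstance(node, torch.fx.Node):
        if node.op == "call_module":
            submodule = model.get_submodule(node.target)
            if isinstance(submodule, (torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)):
                return node                 # found the (observed) LSTM
            if isinstance(submodule, DeQuantStub):
                node = node.args[0]         # dqstub(x)     -> x
                continue
            return None
        if node.op == "call_function" and node.target is operator.getitem:
            node = node.args[0]             # x[i]          -> x
        elif node.op == "call_function" and node.target is tuple:
            node = node.args[0][0]          # tuple([a, b]) -> a
        else:
            return None
    return None
```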

As for convert:

(1) When converting the custom module, we change
quantize - dequantize - custom_module to quantize - custom_module
for all inputs of custom module LSTM, including the hidden inputs (in
convert_custom_module). This is consistent with how we already handle
custom module inputs today, except that it handles more than one input.

(2) We replace DeQuantStubs with "dequantize" nodes (a sketch of this
step follows below). This is not really a special case, since we plan
to do this for all custom modules in the future.
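
A minimal sketch of what (2) amounts to, assuming each DeQuantStub appears as a
call_module node in the observed graph (illustrative only, not the actual
convert code):

```python
import torch
import torch.fx
from torch.ao.quantization import DeQuantStub

def replace_dequant_stubs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in list(model.graph.nodes):
        if node.op == "call_module" and isinstance(
                model.get_submodule(node.target), DeQuantStub):
            # Swap dqstub(x) for x.dequantize() and remove the stub node
            with model.graph.inserting_before(node):
                dequant = model.graph.call_method("dequantize", (node.args[0],))
            node.replace_all_uses_with(dequant)
            model.graph.erase_node(node)
    model.graph.lint()
    model.recompile()
    return model
```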

Test plan:
python test/test_quantization.py TestQuantizeFx.test_static_lstm
python test/test_quantization.py TestQuantizeFx.test_static_lstm_consume_tuple

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

Summary: Static quantization for LSTM was previously only
supported in eager mode. This is because we had to go through
the custom module flow (torch.ao.nn.quantizable.LSTM). However,
this flow returned multiple outputs, and this is currently not
supported in FX graph mode quantization.

In this commit, we work around this limitation by making two
important modifications to this flow:

(1) First, the output of the LSTM node is in the form (result,
(hidden0, hidden1)), and each of the three internal nodes was
already observed, so there is actually no need to observe the
output again. This commit removes the extra output observer that
tried to observe the whole tuple, which failed previously.

(2) Second, the inputs of the LSTM node must be quantized, and
the outputs must be dequantized. This was not handled correctly
previously because FX graph mode quantization did not have special
logic for tuples. For the output in particular, we had to insert
ops to manually split the tuple, insert dequantize nodes, and
recombine them into the original format.

Note that the changes here are intended only as temporary hacks
to enable this flow. In the future, we should add better support
for dtype inference and tuple handling in FX graph mode quantization.

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_static_lstm

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

[ghstack-poisoned]

pytorch-bot bot commented Sep 15, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85068

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 1 Pending

As of commit ad4ddce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: quantization release notes category label Sep 15, 2022
andrewor14 added a commit that referenced this pull request Sep 15, 2022
andrewor14 added a commit that referenced this pull request Sep 15, 2022

m = MyModel()
qconfig_mapping = get_default_qconfig_mapping()
prepare_custom_config = PrepareCustomConfig() \

nit: maybe we can have a get_default_prepare_custom_config as well (can be in a separate PR)

andrewor14 added a commit that referenced this pull request Sep 15, 2022
@andrewor14 andrewor14 marked this pull request as draft September 17, 2022 00:26
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 18, 2022
andrewor14 added a commit that referenced this pull request Sep 19, 2022
andrewor14 added a commit that referenced this pull request Sep 21, 2022
andrewor14 added a commit that referenced this pull request Sep 21, 2022
andrewor14 added a commit that referenced this pull request Sep 21, 2022
andrewor14 added a commit that referenced this pull request Sep 21, 2022
ghstack-source-id: 39c4989fa6425c59d23c3704d3ff0becb75069ac
Pull Request resolved: #85068
@andrewor14
Contributor Author

> can we write down all the special cases we make for lstm in prepare and convert step? this will help us to evaluate how well this is aligned with the long term plan

Sure. In prepare, we make two special cases:

  1. When inserting output observers, we instead break the LSTM output tuple into its internal nodes, insert a DeQuantStub after each node, and recombine the tuple according to the original format (code). This diverges from, but is conceptually the same as, how we insert "special" output observers that will be converted to "dequantize" nodes (during convert) today for other custom modules. In the future, we should just insert DeQuantStubs instead of these "special" observers for custom modules in general.

  2. When inserting input observers for a node, we check whether it is a consumer of LSTM by traversing up the subgraph created in (1) (code). Once we have identified the previous node, we reuse the existing code to decide whether or not to insert the input observer based on the output dtype of the previous node.

As for convert:

  1. When converting the custom module, we change `quantize - dequantize - custom_module` to `quantize - custom_module` for all inputs of custom module LSTM, including the hidden inputs (code); a rough sketch of this rewrite is shown below. This is consistent with how we handle other custom module inputs today, except that it handles more than one input.

  2. We replace DeQuantStubs with "dequantize" nodes. This is not really a special case, since we plan to do this for all custom modules in the future.
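
For convert-side case (1), the rewrite amounts to graph surgery along these lines (a simplified sketch, assuming "dequantize" appears as a `call_method` node; this is not the actual convert code):

```python
from torch.fx import GraphModule, Node

def _skip_dequant_before_custom_module(gm: GraphModule, custom_module_node: Node) -> None:
    # Rewrite quantize -> dequantize -> custom_module into quantize -> custom_module
    # for every input of the custom module, including hidden inputs nested in tuples.
    for arg in list(custom_module_node.all_input_nodes):
        if arg.op == "call_method" and arg.target == "dequantize":
            custom_module_node.replace_input_with(arg, arg.args[0])
    gm.graph.eliminate_dead_code()
    gm.recompile()
```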

andrewor14 added a commit that referenced this pull request Sep 21, 2022
ghstack-source-id: d5a67e565fca7f62d2d22c16b6ec662e7079351c
Pull Request resolved: #85068
return matched_node
return None

def _reroute_tuple_getitem_pattern(graph: Graph):
Contributor


cc @SherlockNoMad do we want to support something similar in some default fx optimization pass?

Contributor

@jerryzh168 jerryzh168 left a comment


looks good, thanks for the hard work!

@jerryzh168
Contributor

jerryzh168 commented Sep 22, 2022


thanks, please add this to the Summary as well so that we can use this to fix the hacks in the future

andrewor14 added a commit that referenced this pull request Sep 22, 2022
**Implementation:**

In prepare, we make two special cases:

(1) When inserting output observers, we instead break the LSTM output
tuple into its internal nodes, insert a DeQuantStub after each node,
and recombine the tuple according to the original format
(in `_insert_dequant_stubs_for_custom_module_lstm_output`). This
diverges from, but is conceptually the same as, how we insert "special"
output observers that will be converted to "dequantize" nodes (during
convert) today for other custom modules. In the future, we should just
insert DeQuantStubs instead of these "special" observers for custom
modules in general.

(2) When inserting input observers for a node, we check whether it is
a consumer of LSTM by traversing up the subgraph created in (1)
(see `_maybe_get_custom_module_lstm_from_node_arg`; a simplified sketch
of this traversal follows this section). Once we have identified the
previous node, we reuse the existing code in
`maybe_insert_input_observer_for_arg_or_kwarg` to decide whether or
not to insert the input observer based on the output dtype of the
previous node.

As for convert:

(1) When converting the custom module, we change
`quantize - dequantize - custom_module` to `quantize - custom_module`
for all inputs of custom module LSTM, including the hidden inputs (in
`convert_custom_module`). This is consistent with how we do the same
for custom module inputs today, except it handles more than one input.

(2) We replace DeQuantStubs with "dequantize" nodes. This is not
really a special case, since we plan to do this for all custom modules
in the future.
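
A simplified sketch of the traversal in prepare case (2) (illustrative only; the real `_maybe_get_custom_module_lstm_from_node_arg` also handles the tuple-recombination nodes inserted in case (1)):

```python
import operator
from typing import Dict, Optional

import torch
from torch.fx import Node

def _lstm_behind_getitems(arg: Node, modules: Dict[str, torch.nn.Module]) -> Optional[Node]:
    # Walk up through getitem indirection, e.g. consumer <- lstm[0] <- lstm, and
    # return the LSTM node if the chain bottoms out at an observed quantizable LSTM.
    # `modules` is expected to be dict(model.named_modules()).
    node = arg
    while (
        isinstance(node, Node)
        and node.op == "call_function"
        and node.target is operator.getitem
    ):
        node = node.args[0]
    if (
        isinstance(node, Node)
        and node.op == "call_module"
        and isinstance(modules.get(str(node.target)), torch.ao.nn.quantizable.LSTM)
    ):
        return node
    return None
```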

**Test plan:**

python test/test_quantization.py TestQuantizeFx.test_static_lstm

python test/test_quantization.py TestQuantizeFx.test_static_lstm_consume_tuple

python test/test_quantization.py TestQuantizeFx.test_reroute_tuple_getitem_patterns

**Reviewers:** jerryzh168, vkuzo

**Subscribers:** jerryzh168, vkuzo

ghstack-source-id: 06f45764ec461dcd443617cd43964b352fa490d2
Pull Request resolved: #85068
@andrewor14
Contributor Author

Ok, I'm merging this. Thanks for all the feedback!

@andrewor14
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions

Hey @andrewor14.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@andrewor14 andrewor14 added the topic: improvements topic category label Sep 23, 2022
@facebook-github-bot facebook-github-bot deleted the gh/andrewor14/27/head branch September 26, 2022 14:20
mehtanirav pushed a commit that referenced this pull request Oct 4, 2022
Pull Request resolved: #85068
Approved by: https://github.com/jerryzh168
alvgaona pushed a commit to alvgaona/pytorch that referenced this pull request Oct 11, 2022
Pull Request resolved: pytorch#85068
Approved by: https://github.com/jerryzh168