[quantization] Fix tracing for dynamic quantized LSTM #29331

Closed
jamesr66a wants to merge 2 commits

Conversation

@jamesr66a (Collaborator) commented Nov 6, 2019

Stack from ghstack:

Closes #27954

This fixes the hard-coding of packed parameter values for the dynamic quantized LSTM by orchestrating the following dance:

  1. Each variadic parameter on the module has its own Module. That Module defines the `__getstate__` and `__setstate__` methods s.t. packed weights are properly re-packed on model load.
  2. Each of these modules is wrapped into a `torch.nn.ModuleList`, s.t. the parameters appear as attributes in the hierarchy. Then, `gatherParametersAndBuffers` (https://github.com/pytorch/pytorch/blob/9c43b16df9dad3dfb4da1efab68d8c88e6437e8f/torch/csrc/jit/tracer.cpp#L285) can see these parameters and create a `Value*` for them in the traced graph.
  3. In `forward`, we need to convert from ModuleList -> Module -> Parameter to a simple TensorList of the parameters. We just use a loop here. In tracing, we simply record a `ListConstruct` with each of the proper parameter values. In scripting, the `ModuleList` is const, so it can be unrolled into the graph and a subsequent `ListConstruct` does its business. A rough sketch of this arrangement follows below.
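
To make the dance concrete, here is a minimal sketch written against the public PyTorch API. It is illustrative only: the class names, the use of `torch.ops.quantized.linear_prepack`/`linear_unpack` for the pack/unpack round trip, and the exact attribute layout are assumptions based on this description and the traced graph below, not the literal implementation in `torch/nn/quantized/dynamic`.

```python
import torch


class PackedParameter(torch.nn.Module):
    # Step 1: one wrapper Module per variadic packed parameter. __getstate__/
    # __setstate__ ensure the weight is re-packed when the model is loaded.
    def __init__(self, param):
        super(PackedParameter, self).__init__()
        self.param = param  # backend-specific packed weight

    def __getstate__(self):
        # Serialize the unpacked tensors; the packed format is not portable.
        # (Assumes a quantized-linear-style pack/unpack pair.)
        return torch.ops.quantized.linear_unpack(self.param), self.training

    def __setstate__(self, state):
        weight, bias = state[0]
        self.param = torch.ops.quantized.linear_prepack(weight, bias)
        self.training = state[1]


class DynamicQuantizedLSTMSketch(torch.nn.Module):
    def __init__(self, packed_params):
        super(DynamicQuantizedLSTMSketch, self).__init__()
        # Step 2: a ModuleList puts each wrapper into the module hierarchy, so
        # the tracer's gatherParametersAndBuffers sees the packed weights and
        # emits a Value* for each instead of freezing them in as constants.
        self._all_weight_values = torch.nn.ModuleList(
            [PackedParameter(p) for p in packed_params])

    def forward(self, input, hx):
        # Step 3: flatten ModuleList -> Module -> param into a plain list.
        # Tracing records a ListConstruct over the parameter Values; scripting
        # unrolls the const ModuleList and then builds the same list.
        weights = []
        for wrapper in self._all_weight_values:
            weights.append(wrapper.param)
        # Flags mirror the traced call shown below; dtype=torch.qint8 shows up
        # as dtype=12 in the traced graph.
        result, hy, cy = torch.quantized_lstm(
            input, list(hx), weights,
            True, 1, 0., True, False, False,
            dtype=torch.qint8, use_dynamic=True)
        return result, (hy, cy)
```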

The `forward` of the traced LSTM before and after this change is as follows:

Before

```
def forward(self,
    input: Tensor,
    argument_2: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
  hx, hx0, = argument_2
  _0, _1, _2 = torch.quantized_lstm(input, [hx, hx0], [CONSTANTS.c0, CONSTANTS.c1], True, 1, 0., True, False, False, dtype=12, use_dynamic=True)
  return (_0, (_1, _2))
```

After

```
def forward(self,
    input: Tensor,
    argument_2: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
  _0 = self.cell._all_weight_values
  _1 = getattr(_0, "0").param
  _2 = getattr(_0, "1").param
  hx, hx0, = argument_2
  _3, _4, _5 = torch.quantized_lstm(input, [hx, hx0], [_1, _2], True, 1, 0., True, False, False, dtype=12, use_dynamic=True)
  return (_3, (_4, _5))
```
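
For context, the end-to-end workflow this is meant to support (the use case from #27954) looks roughly like the following. The shapes and the `quantize_dynamic` call are illustrative assumptions following the API of this era of PyTorch, not the PR's actual test.

```python
import torch

# Illustrative only: dynamically quantize an LSTM, trace it, and round-trip it
# through save/load. Before this fix the packed weights were hard-coded into
# the trace as constants (CONSTANTS.c0/c1 above); now they are read from the
# module hierarchy, so the reloaded module re-packs and uses its own weights.
model = torch.nn.LSTM(input_size=4, hidden_size=4, num_layers=1)
qmodel = torch.quantization.quantize_dynamic(
    model, {torch.nn.LSTM}, dtype=torch.qint8)

x = torch.randn(5, 3, 4)        # (seq_len, batch, input_size)
h0 = torch.zeros(1, 3, 4)
c0 = torch.zeros(1, 3, 4)

traced = torch.jit.trace(qmodel, (x, (h0, c0)))

traced.save("quantized_lstm.pt")
reloaded = torch.jit.load("quantized_lstm.pt")
out, (hn, cn) = reloaded(x, (h0, c0))
```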

Differential Revision: D18374904

@jamesr66a requested a review from apaszke as a code owner November 6, 2019 21:43
jamesr66a pushed a commit that referenced this pull request Nov 6, 2019
ghstack-source-id: 88e30b4
Pull Request resolved: #29331

@zdevito (Contributor) left a comment

Looks cleaner to me!

@facebook-github-bot (Contributor) commented:
@jamesr66a merged this pull request in f17e02f.

@mruberry (Collaborator) commented Nov 7, 2019

Reverted.

@mruberry reopened this Nov 7, 2019
jamesr66a pushed a commit that referenced this pull request Nov 7, 2019
ghstack-source-id: 897155b
Pull Request resolved: #29331
@facebook-github-bot deleted the gh/jamesr66a/139/head branch November 11, 2019 15:16