Add wav2vec2.0 model #1529

Merged: mthrok merged 8 commits into master from w2v2-pr-model on May 27, 2021

Conversation

@mthrok (Collaborator) commented May 26, 2021

This PR adds

  • TorchScript-able Wav2Vec2Model class
  • Factory functions for three configurations presented in the paper
    • wav2vec2_base
    • wav2vec2_large
    • wav2vec2_large_lv60k

ref: #1506
supersedes #1525
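A hypothetical usage sketch (not part of the PR): the exact factory signatures are not shown in this description, so the `num_out` argument below is an assumption. It illustrates the two advertised properties, building a model from a factory function and scripting it.

import torch
import torchaudio

# Build the base configuration; "num_out" is an assumed parameter name.
model = torchaudio.models.wav2vec2_base(num_out=32)

# Wav2Vec2Model is advertised as TorchScript-able, so this should succeed.
scripted = torch.jit.script(model)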

@mthrok (Collaborator, Author) commented May 26, 2021

Quantization tests are failing on macOS CI, but when I ran them locally with the latest PyTorch nightly, they passed:

2021-05-25 23:45:43 moto@moto-mbp:/Users/moto/Development/torchaudio  (base)
% pytest test/torchaudio_unittest/models/wav2vec2 -k quant -v
============================================================================================================ test session starts =============================================================================================================
platform darwin -- Python 3.8.5, pytest-6.2.2, py-1.10.0, pluggy-0.13.1 -- /Users/moto/miniconda3/bin/python
cachedir: .pytest_cache
rootdir: /Users/moto/Development/torchaudio
collected 18 items / 12 deselected / 6 selected

test/torchaudio_unittest/models/wav2vec2/model_test.py::TestWav2Vec2Model::test_quantize_0 PASSED                                                                                                                                      [ 16%]
test/torchaudio_unittest/models/wav2vec2/model_test.py::TestWav2Vec2Model::test_quantize_1 PASSED                                                                                                                                      [ 33%]
test/torchaudio_unittest/models/wav2vec2/model_test.py::TestWav2Vec2Model::test_quantize_2 PASSED                                                                                                                                      [ 50%]
test/torchaudio_unittest/models/wav2vec2/model_test.py::TestWav2Vec2Model::test_quantize_torchscript_0 PASSED                                                                                                                          [ 66%]
test/torchaudio_unittest/models/wav2vec2/model_test.py::TestWav2Vec2Model::test_quantize_torchscript_1 PASSED                                                                                                                          [ 83%]
test/torchaudio_unittest/models/wav2vec2/model_test.py::TestWav2Vec2Model::test_quantize_torchscript_2 PASSED                                                                                                                          [100%]

===================================================================================================== 6 passed, 12 deselected in 52.63s ======================================================================================================

@vkuzo Have you ever seen an error like this? I guess it's a CI / PyTorch nightly packaging issue, but if you have any insight, that would be helpful as well.

From https://app.circleci.com/pipelines/github/pytorch/audio/6076/workflows/6958ea54-1bb4-482e-8d7b-c42de0bf0a4b/jobs/216325

self = <[AttributeError("'LinearPackedParams' object has no attribute '_packed_params'",) raised in repr()] LinearPackedParams object at 0x7faa652edef0>
weight = tensor([[0.]], size=(1, 1), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)
bias = None

    @torch.jit.export
    def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
        if self.dtype == torch.qint8:
>           self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
E           RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine


shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L
@cpuhrsch (Contributor) commented May 26, 2021
nit: Why permute and not transpose?

@mthrok (Collaborator, Author)

I would need to call transpose twice to achieve this, and I thought permute was more readable.

@cpuhrsch (Contributor)

Ah I see, you're merging the transpose needed for the weights below
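A minimal sketch of the exchange above: for `k`, a single `permute(0, 2, 3, 1)` folds together the head transpose and the extra transpose needed before the attention matmul, which would otherwise take two `transpose` calls.

import torch

batch_size, length, num_heads, head_dim = 2, 50, 12, 64
k = torch.randn(batch_size, length, num_heads, head_dim)  # after view(*shape)

a = k.permute(0, 2, 3, 1)              # (B, nH, Hd, L) in one call
b = k.transpose(1, 2).transpose(2, 3)  # same layout via two transposes
assert torch.equal(a, b)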

shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
@cpuhrsch (Contributor) commented May 26, 2021

not blocking: All these projections consume the same input, so you could do it using one linear with 3x the embedding dim. You could also call into nn.MHA altogether.
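A minimal sketch of this suggestion (shapes assumed, not the PR's code): one `nn.Linear` with 3x the output width computes all three projections in a single matmul, and the result is then split into q, k, and v.

import torch

embed_dim = 768
qkv_proj = torch.nn.Linear(embed_dim, 3 * embed_dim)

x = torch.randn(2, 50, embed_dim)       # (batch, length, embed_dim)
q, k, v = qkv_proj(x).chunk(3, dim=-1)  # each (batch, length, embed_dim)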

mask = torch.arange(max_len).expand(batch_size, max_len) >= lengths[:, None]
x[mask] = 0.0
# extend the mask to attention shape and set weight
mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
@cpuhrsch (Contributor) commented May 26, 2021

not blocking: There's contention around what the right value here is. ParlAI uses a neginf that depends on the input dtype; nn.MHA uses float("-inf"), which likely has issues with lower-precision dtypes; FasterTransformer uses -10000; and fairseq uses float("-inf") as well. I think -10000.0 is fine, but I'm curious about your reasoning behind it.

@mthrok (Collaborator, Author)

This value is from Hugging Face's implementation of wav2vec 2.0.

Let me try float("-inf"); if the test passes, I will switch to it.

@mthrok (Collaborator, Author)

I could not make it work with float("-inf"), so I will stick with -10000.0.
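A minimal sketch (not the PR's test) of one known pitfall with float("-inf"): a row that is masked at every position softmaxes to NaN, whereas a large finite value such as -10000.0 keeps the softmax well defined. Whether this is exactly what failed here is not stated in the thread.

import torch

scores = torch.zeros(1, 4)  # attention scores for one fully masked row

inf_masked = scores + torch.full((1, 4), float("-inf"))
finite_masked = scores + torch.full((1, 4), -10000.0)

print(inf_masked.softmax(dim=-1))     # tensor([[nan, nan, nan, nan]])
print(finite_masked.softmax(dim=-1))  # tensor([[0.25, 0.25, 0.25, 0.25]])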

@vkuzo (Contributor) commented May 26, 2021

> @vkuzo Have you ever seen an error like this? I guess it's a CI / PyTorch nightly packaging issue, but if you have any insight, that would be helpful as well.
>
> E RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine

This means that neither fbgemm nor qnnpack is available in the environment. Quantization is not supported when neither of those is available, unless the user is using a custom out-of-tree quantization backend. We have seen this before, and we usually just gate the test on the availability of fbgemm/qnnpack. You could check out the `override_qengines` decorator in torch/testing/_internal/common_quantized.py for a util to do this.
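A minimal sketch of that gating approach (not the `override_qengines` decorator itself): skip quantization tests when neither fbgemm nor qnnpack is present in the build, which is what the NoQEngine error indicates.

import unittest
import torch

def skip_if_no_qengine(test_item):
    # torch.backends.quantized.supported_engines lists the engines this
    # build was compiled with ("none" is always present).
    engines = torch.backends.quantized.supported_engines
    available = any(e in engines for e in ("fbgemm", "qnnpack"))
    return unittest.skipUnless(available, "no quantized engine available")(test_item)

class TestQuantize(unittest.TestCase):
    @skip_if_no_qengine
    def test_quantize_linear(self):
        model = torch.nn.Linear(4, 4)
        quantized = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8)
        self.assertIsNotNone(quantized)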

@mthrok mthrok merged commit e6886a4 into master May 27, 2021
@mthrok mthrok deleted the w2v2-pr-model branch May 27, 2021 13:12
@mthrok mthrok restored the w2v2-pr-model branch May 27, 2021 13:14
@mthrok mthrok deleted the w2v2-pr-model branch May 27, 2021 13:15
mthrok pushed a commit to mthrok/audio that referenced this pull request Dec 13, 2022
* Add loss_fn as an input to the test function

The function `test()` depends on `loss_fn` just like `train()` does, so it is better to provide `loss_fn` explicitly as an argument than to rely on a global `loss_fn` object. This is also more consistent with how `train()` is defined, and it makes explicit that `train()` depends on `optimizer` while `test()` does not (see the sketch below).

* Update quickstart_tutorial.py

Co-authored-by: Holly Sweeney <77758406+holly1238@users.noreply.github.com>
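A minimal sketch of the refactor that commit describes (names follow the PyTorch quickstart tutorial; the dataloader and model are assumed): `loss_fn` becomes an explicit parameter of `test()` instead of a global.

import torch

def test(dataloader, model, loss_fn):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            total_loss += loss_fn(pred, y).item()
    print(f"Avg loss: {total_loss / len(dataloader):>8f}")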