Skip to content

Commit

Permalink
Add conformer w2v2 model architecture (#2826)
Browse files Browse the repository at this point in the history
Summary:
Internal comparison tests: D40080919.

Follow-up PR for pretrained models: #2827.

Pull Request resolved: #2826

Reviewed By: nateanl

Differential Revision: D41160061

Pulled By: carolineechen

fbshipit-source-id: f3c478b28c235af53d1d8e21b573c53684a63ac4
  • Loading branch information
Caroline Chen authored and facebook-github-bot committed Nov 10, 2022
1 parent bd76d3d commit 74f9a89
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/source/prototype.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,13 @@ ConvEmformer
.. automethod:: forward

.. automethod:: infer

conformer_wav2vec2_model
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: conformer_wav2vec2_model

conformer_wav2vec2_base
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: conformer_wav2vec2_base
9 changes: 9 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,15 @@ @article{coucke2018snips
journal={arXiv preprint arXiv:1805.10190},
year={2018}
}
@INPROCEEDINGS{9746490,
author={Srivastava, Sangeeta and Wang, Yun and Tjandra, Andros and Kumar, Anurag and Liu, Chunxi and Singh, Kritika and Saraf, Yatharth},
booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Conformer-Based Self-Supervised Learning For Non-Speech Audio Tasks},
year={2022},
volume={},
number={},
pages={8862-8866},
doi={10.1109/ICASSP43922.2022.9746490}}
@article{chen2022wavlm,
title={Wavlm: Large-scale self-supervised pre-training for full stack speech processing},
author={Chen, Sanyuan and Wang, Chengyi and Chen, Zhengyang and Wu, Yu and Liu, Shujie and Chen, Zhuo and Li, Jinyu and Kanda, Naoyuki and Yoshioka, Takuya and Xiao, Xiong and others},
Expand Down
99 changes: 99 additions & 0 deletions test/torchaudio_unittest/prototype/conformer_wav2vec2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
from parameterized import parameterized
from torchaudio.prototype.models import conformer_wav2vec2_base
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase


class TestConformerWav2Vec2(TorchaudioTestCase):
    """Tests for the model built by ``conformer_wav2vec2_base``.

    Covers forward-pass smoke tests on CPU/CUDA, per-layer feature
    extraction, zero-length inputs, and TorchScript consistency.
    """

    def _make_inputs(self, device=None, dtype=None, batch_size=3, num_frames=1024, in_features=64):
        """Create a random feature batch and matching random lengths.

        Args:
            device: Device on which both tensors are created (default device if ``None``).
            dtype: Floating point dtype of the features (default dtype if ``None``).
            batch_size: Number of sequences in the batch.
            num_frames: Number of frames per (padded) sequence.
            in_features: Feature dimension of every frame.

        Returns:
            Tuple ``(features, lengths)`` where ``features`` has shape
            ``(batch_size, num_frames, in_features)`` and ``lengths`` has shape
            ``(batch_size,)`` with integer values drawn from ``[0, num_frames)``.
        """
        # Features are drawn before lengths so the random stream is consumed
        # in the same order as the previously inlined code.
        features = torch.randn(batch_size, num_frames, in_features, device=device, dtype=dtype)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size], device=device)
        return features, lengths

    def _smoke_test(self, model, device, dtype):
        """Run a single forward pass of ``model`` on ``device`` with ``dtype`` inputs."""
        model = model.to(device=device, dtype=dtype)
        model = model.eval()

        features, lengths = self._make_inputs(device=device, dtype=dtype)

        model(features, lengths)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """The forward pass runs on CPU for each supported dtype."""
        model = conformer_wav2vec2_base()
        self._smoke_test(model, torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """The forward pass runs on CUDA for each supported dtype."""
        model = conformer_wav2vec2_base()
        self._smoke_test(model, torch.device("cuda"), dtype)

    def test_extract_feature(self):
        """``extract_features`` returns one 3-D tensor per requested layer.

        Also checks that requesting the first ``k`` layers yields the same
        tensors as the first ``k`` entries of the full extraction.
        """
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size = 3
        num_layers = len(model.encoder.conformer)
        features, lengths = self._make_inputs(batch_size=batch_size)

        all_features, lengths_ = model.extract_features(features, lengths, num_layers=None)
        # unittest assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report both operands on failure.
        self.assertEqual(len(all_features), num_layers)
        for feats in all_features:
            self.assertEqual(feats.ndim, 3)
            self.assertEqual(feats.shape[0], batch_size)
        self.assertEqual(lengths_.shape, torch.Size([batch_size]))

        for layer_count in range(1, num_layers + 1):
            feats, lengths_ = model.extract_features(features, lengths, num_layers=layer_count)
            self.assertEqual(len(feats), layer_count)
            for expected, actual in zip(all_features, feats):
                self.assertEqual(expected, actual)
            self.assertEqual(lengths_.shape, torch.Size([batch_size]))

    def test_zero_length(self):
        """All-zero input lengths produce all-zero output lengths."""
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features)
        input_lengths = torch.zeros(batch_size)
        _, output_lengths = model(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

        _, output_lengths = model.extract_features(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

    def test_torchscript_consistency(self):
        """The scripted model reproduces the eager model's outputs and lengths."""
        model = conformer_wav2vec2_base()
        model.eval()

        features, lengths = self._make_inputs()

        ref_out, ref_len = model(features, lengths)

        scripted = torch_script(model)
        hyp_out, hyp_len = scripted(features, lengths)

        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)
3 changes: 3 additions & 0 deletions torchaudio/prototype/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ._conformer_wav2vec2 import conformer_wav2vec2_base, conformer_wav2vec2_model
from .conv_emformer import ConvEmformer
from .rnnt import conformer_rnnt_base, conformer_rnnt_model

# Names exported by this package; each one is imported at the top of this module.
__all__ = [
"conformer_rnnt_base",
"conformer_rnnt_model",
"ConvEmformer",
"conformer_wav2vec2_model",
"conformer_wav2vec2_base",
]

0 comments on commit 74f9a89

Please sign in to comment.