-
Notifications
You must be signed in to change notification settings - Fork 633
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add conformer w2v2 model architecture (#2826)
Summary: internal comparison tests: D40080919 follow up PR for pretrained models #2827 Pull Request resolved: #2826 Reviewed By: nateanl Differential Revision: D41160061 Pulled By: carolineechen fbshipit-source-id: f3c478b28c235af53d1d8e21b573c53684a63ac4
- Loading branch information
1 parent
bd76d3d
commit 74f9a89
Showing
5 changed files
with
441 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
99 changes: 99 additions & 0 deletions
99
test/torchaudio_unittest/prototype/conformer_wav2vec2_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import torch | ||
from parameterized import parameterized | ||
from torchaudio.prototype.models import conformer_wav2vec2_base | ||
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase | ||
|
||
|
||
class TestConformerWav2Vec2(TorchaudioTestCase):
    """Tests for the model built by ``conformer_wav2vec2_base``.

    Every test feeds a random ``(3, 1024, 64)`` feature tensor to the model,
    together with a per-sample lengths tensor.
    """

    # Dimensions of the random input batch shared by all tests.
    _BATCH_SIZE = 3
    _NUM_FRAMES = 1024
    _IN_FEATURES = 64

    def _random_features(self, device=None, dtype=None):
        """Return a random feature tensor of the shared input dimensions."""
        return torch.randn(
            self._BATCH_SIZE,
            self._NUM_FRAMES,
            self._IN_FEATURES,
            device=device,
            dtype=dtype,
        )

    def _random_lengths(self, device=None):
        """Return one random integer length in ``[0, _NUM_FRAMES)`` per sample."""
        return torch.randint(
            low=0,
            high=self._NUM_FRAMES,
            size=(self._BATCH_SIZE,),
            device=device,
        )

    def _smoke_test(self, model, device, dtype):
        """Run one forward pass of ``model`` on ``device`` with ``dtype`` inputs."""
        model = model.to(device=device, dtype=dtype).eval()

        # Features are drawn before lengths so the RNG is consumed in a fixed order.
        features = self._random_features(device=device, dtype=dtype)
        lengths = self._random_lengths(device=device)

        model(features, lengths)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """The forward pass runs on CPU for each parameterized dtype."""
        self._smoke_test(conformer_wav2vec2_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """The forward pass runs on CUDA for each parameterized dtype."""
        self._smoke_test(conformer_wav2vec2_base(), torch.device("cuda"), dtype)

    def test_extract_feature(self):
        """``extract_features`` returns consistent per-layer outputs.

        Requesting the first ``k`` layers must give the same tensors as the
        first ``k`` entries obtained when all layers are requested.
        """
        model = conformer_wav2vec2_base()
        model.eval()

        layer_count = len(model.encoder.conformer)
        features = self._random_features()
        lengths = self._random_lengths()
        expected_lengths_shape = torch.Size([self._BATCH_SIZE])

        # Reference run: ``num_layers=None`` is expected to return every layer.
        all_features, out_lengths = model.extract_features(features, lengths, num_layers=None)
        assert len(all_features) == layer_count
        for layer_out in all_features:
            assert layer_out.ndim == 3
            assert layer_out.shape[0] == self._BATCH_SIZE
        assert out_lengths.shape == expected_lengths_shape

        # Truncated runs must reproduce the leading entries of the reference run.
        for depth in range(1, layer_count + 1):
            partial, out_lengths = model.extract_features(features, lengths, num_layers=depth)
            assert len(partial) == depth
            for reference_out, partial_out in zip(all_features, partial):
                self.assertEqual(reference_out, partial_out)
            assert out_lengths.shape == expected_lengths_shape

    def test_zero_length(self):
        """All-zero input lengths give all-zero output lengths."""
        model = conformer_wav2vec2_base()
        model.eval()

        features = self._random_features()
        input_lengths = torch.zeros(self._BATCH_SIZE)

        # Both the forward pass and ``extract_features`` are checked.
        _, output_lengths = model(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

        _, output_lengths = model.extract_features(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

    def test_torchscript_consistency(self):
        """The scripted model matches the eager model on the same inputs."""
        model = conformer_wav2vec2_base()
        model.eval()

        features = self._random_features()
        lengths = self._random_lengths()

        # The eager reference is computed before the model is scripted.
        ref_out, ref_len = model(features, lengths)

        scripted = torch_script(model)
        hyp_out, hyp_len = scripted(features, lengths)

        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
from ._conformer_wav2vec2 import conformer_wav2vec2_base, conformer_wav2vec2_model | ||
from .conv_emformer import ConvEmformer | ||
from .rnnt import conformer_rnnt_base, conformer_rnnt_model | ||
|
||
# Names exported by ``from <package> import *``; each one is imported above.
__all__ = [
    "conformer_rnnt_base",
    "conformer_rnnt_model",
    "ConvEmformer",
    "conformer_wav2vec2_model",
    "conformer_wav2vec2_base",
]
Oops, something went wrong.