Add model Wav2Letter #462
Changes from 10 commits
15 commits
c7d6ece  add wav2letter model  (tomassosorio)
17b7078  add unit_test to model  (tomassosorio)
48c312b  add docstrings  (tomassosorio)
aea6fad  add documentation  (tomassosorio)
5b68bab  fix minor error, change logic on forward  (tomassosorio)
064e923  update padding same with ceil  (tomassosorio)
0a7f0bf  Merge branch 'master' into addModelWav2Letter  (tomassosorio)
390e176  add inline typing and minor fixes to docstrings  (tomassosorio)
48df4fb  remove python2  (tomassosorio)
2fc356b  add formula do docstrings, change param name  (tomassosorio)
895a2e6  Merge branch 'master' into addModelWav2Letter  (tomassosorio)
d4fb114  Merge branch 'master' into addModelWav2Letter  (tomassosorio)
34785ec  add test with mfcc, add pytest  (tomassosorio)
63def38  fix bug, update docstrings  (tomassosorio)
9f9e79d  change parameter name  (tomassosorio)
@@ -0,0 +1,17 @@
.. role:: hidden
    :class: hidden-section

torchaudio.models
======================

.. currentmodule:: torchaudio.models

The models subpackage contains definitions of models for addressing common audio tasks.


:hidden:`Wav2Letter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Wav2Letter

  .. automethod:: forward
@@ -0,0 +1,21 @@
import unittest

import torch
from torchaudio.models import Wav2Letter


class ModelTester(unittest.TestCase):
    def test_wav2letter(self):
        batch_size = 2
        n_features = 1
        input_length = 320

        model = Wav2Letter()
        x = torch.rand(batch_size, n_features, input_length)
        out = model(x)

        # forward() as written returns (batch_size, num_classes, output_length)
        assert out.size() == (batch_size, 40, 2)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1 @@
from .wav2letter import *
@@ -0,0 +1,73 @@
from torch import Tensor
from torch import nn

__all__ = ["Wav2Letter"]


class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
    <https://arxiv.org/abs/1609.03193>`_ paper.

    :math:`\text{padding} = \left\lceil \frac{\text{kernel} - \text{stride}}{2} \right\rceil`

    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        version (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
            or ``mfcc`` (Default: ``waveform``).
        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
    """

    def __init__(self, num_classes: int = 40,
                 version: str = "waveform",
                 num_features: int = 1) -> None:
        super().__init__()

        # With waveform input, the front-end convolution below emits 250 channels,
        # so the first acoustic layer must accept 250 channels instead of num_features.
        acoustic_num_features = 250 if version == "waveform" else num_features
        acoustic_model = nn.Sequential(
            nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )

        if version == "waveform":
            waveform_model = nn.Sequential(
                nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
                nn.ReLU(inplace=True)
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
        elif version in ["power_spectrum", "mfcc"]:
            self.acoustic_model = acoustic_model
        else:
            # Fail early instead of leaving self.acoustic_model undefined.
            raise ValueError(f"Unsupported version: {version!r}")

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): Tensor of dimension (batch_size, n_features, input_length).

        Returns:
            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, output_length).
        """

        x = self.acoustic_model(x)
        # Log-probabilities over the class dimension.
        x = nn.functional.log_softmax(x, dim=1)
        return x
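The padding formula in the class docstring can be checked against the hard-coded `padding` arguments. This is a minimal sketch, assuming the intended reading of the formula is ceil((kernel - stride) / 2), which is what the layer values suggest. `same_padding` and `LAYERS` are hypothetical names introduced only for illustration, and the tuples are copied by hand from the `nn.Conv1d` calls in the listing above.

```python
import math

# (kernel_size, stride, padding) for each Conv1d in the listing,
# copied by hand; the waveform front-end layer comes first.
LAYERS = [
    (250, 160, 45),        # waveform front end
    (48, 2, 23),
    *([(7, 1, 3)] * 7),    # seven identical kernel-7 layers
    (32, 1, 16),
    (1, 1, 0),
    (1, 1, 0),
]


def same_padding(kernel_size: int, stride: int) -> int:
    """Hypothetical helper: ceil((kernel_size - stride) / 2)."""
    return math.ceil((kernel_size - stride) / 2)


# Collect every layer whose hard-coded padding disagrees with the formula.
mismatches = [(k, s, p) for k, s, p in LAYERS if same_padding(k, s) != p]
print(mismatches)  # an empty list means every layer agrees with the formula
```

The kernel-32 layer is the only one where the ceiling matters: (32 - 1) / 2 is 15.5, which rounds up to the 16 used in the model.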
Should we extend these tests further (e.g. specific input/output tests with fixed weights), or is there little value in it?
From what I saw in torchvision, they had problems with extensive tests because they slowed down Travis.
The test I added relies on the model emitting one letter every 20 ms: at the 16000 Hz sample rate they use, 20 ms corresponds to 320 samples. However, the padding at the 10th layer adds one extra letter to the output.
I used "same" padding with the formula padding = ceil((kernel - stride) / 2), because of a wav2letter++ issue.
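The arithmetic behind the 320-sample test input can be traced layer by layer. This is a rough sketch, assuming the standard output-length formula for a 1-D convolution with dilation 1, floor((L + 2 * padding - kernel) / stride) + 1. `conv1d_out_length`, `trace_lengths` and `LAYERS` are illustrative names, not part of the PR, and the layer tuples are copied by hand from the model definition in waveform mode.

```python
from typing import List

# (kernel_size, stride, padding) per Conv1d, waveform front end first.
LAYERS = [
    (250, 160, 45),
    (48, 2, 23),
    *([(7, 1, 3)] * 7),
    (32, 1, 16),   # the 10th layer: even kernel plus padding 16 adds one step
    (1, 1, 0),
    (1, 1, 0),
]


def conv1d_out_length(length: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


def trace_lengths(input_length: int) -> List[int]:
    """Return the sequence length before the first layer and after each layer."""
    lengths = [input_length]
    for kernel_size, stride, padding in LAYERS:
        lengths.append(conv1d_out_length(lengths[-1], kernel_size, stride, padding))
    return lengths


print(trace_lengths(320))
# [320, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2]
```

Under these assumptions the length stays at 1 through the kernel-7 layers and becomes 2 at the kernel-32 layer, which matches the "one extra letter" described above.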
I feel like this should be mentioned in the documentation.