-
Notifications
You must be signed in to change notification settings - Fork 633
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add wav2letter model * add unit_test to model * add docstrings * add documentation * fix minor error, change logic on forward * update padding same with ceil * add inline typing and minor fixes to docstrings * remove python2 * add formula do docstrings, change param name * add test with mfcc, add pytest * fix bug, update docstrings * change parameter name
- Loading branch information
1 parent
3ecc701
commit d678357
Showing
4 changed files
with
122 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
.. role:: hidden | ||
:class: hidden-section | ||
|
||
torchaudio.models | ||
====================== | ||
|
||
.. currentmodule:: torchaudio.models | ||
|
||
The models subpackage contains definitions of models for addressing common audio tasks. | ||
|
||
|
||
:hidden:`Wav2Letter` | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: Wav2Letter | ||
|
||
.. automethod:: forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pytest | ||
|
||
import torch | ||
from torchaudio.models import Wav2Letter | ||
|
||
|
||
class TestWav2Letter: | ||
@pytest.mark.parametrize('batch_size', [2]) | ||
@pytest.mark.parametrize('num_features', [1]) | ||
@pytest.mark.parametrize('num_classes', [40]) | ||
@pytest.mark.parametrize('input_length', [320]) | ||
def test_waveform(self, batch_size, num_features, num_classes, input_length): | ||
model = Wav2Letter() | ||
|
||
x = torch.rand(batch_size, num_features, input_length) | ||
out = model(x) | ||
|
||
assert out.size() == (batch_size, num_classes, 2) | ||
|
||
@pytest.mark.parametrize('batch_size', [2]) | ||
@pytest.mark.parametrize('num_features', [13]) | ||
@pytest.mark.parametrize('num_classes', [40]) | ||
@pytest.mark.parametrize('input_length', [2]) | ||
def test_mfcc(self, batch_size, num_features, num_classes, input_length): | ||
model = Wav2Letter(input_type="mfcc", num_features=13) | ||
|
||
x = torch.rand(batch_size, num_features, input_length) | ||
out = model(x) | ||
|
||
assert out.size() == (batch_size, num_classes, 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .wav2letter import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Optional | ||
|
||
from torch import Tensor | ||
from torch import nn | ||
|
||
__all__ = ["Wav2Letter"] | ||
|
||
|
||
class Wav2Letter(nn.Module): | ||
r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System" | ||
<https://arxiv.org/abs/1609.03193>`_ paper. | ||
:math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}` | ||
Args: | ||
num_classes (int, optional): Number of classes to be classified. (Default: ``40``) | ||
input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum`` | ||
or ``mfcc`` (Default: ``waveform``). | ||
num_features (int, optional): Number of input features that the network will receive (Default: ``1``). | ||
""" | ||
|
||
def __init__(self, num_classes: int = 40, | ||
input_type: str = "waveform", | ||
num_features: int = 1) -> None: | ||
super(Wav2Letter, self).__init__() | ||
|
||
acoustic_num_features = 250 if input_type == "waveform" else num_features | ||
acoustic_model = nn.Sequential( | ||
nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0), | ||
nn.ReLU(inplace=True), | ||
nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0), | ||
nn.ReLU(inplace=True) | ||
) | ||
|
||
if input_type == "waveform": | ||
waveform_model = nn.Sequential( | ||
nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45), | ||
nn.ReLU(inplace=True) | ||
) | ||
self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) | ||
|
||
if input_type in ["power_spectrum", "mfcc"]: | ||
self.acoustic_model = acoustic_model | ||
|
||
def forward(self, x: Tensor) -> Tensor: | ||
r""" | ||
Args: | ||
x (Tensor): Tensor of dimension (batch_size, num_features, input_length). | ||
Returns: | ||
Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length). | ||
""" | ||
|
||
x = self.acoustic_model(x) | ||
x = nn.functional.log_softmax(x, dim=1) | ||
return x |