Add model Wav2Letter #462

Merged (15 commits) on Apr 28, 2020
17 changes: 17 additions & 0 deletions docs/source/models.rst
@@ -0,0 +1,17 @@
.. role:: hidden
   :class: hidden-section

torchaudio.models
======================

.. currentmodule:: torchaudio.models

The models subpackage contains definitions of models for addressing common audio tasks.


:hidden:`Wav2Letter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Wav2Letter

    .. automethod:: forward
22 changes: 22 additions & 0 deletions test/test_models.py
@@ -0,0 +1,22 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest

import torch
from torchaudio.models import wav2letter


class ModelTester(unittest.TestCase):
    def test_wav2letter(self):
Contributor:
Should we extend these tests further (i.e. specific input / output tests for fixed weights) or is there little value in it?

Contributor Author:
From what I saw in torchvision, they had trouble running extensive tests because they slowed down Travis.

The idea behind the test I added is that the model should output one letter every 20 ms. Since the sample rate they use is 16000 Hz, 20 ms corresponds to 320 samples; however, because of the padding at the 10th layer, one extra letter is added to the output.

I used "same" padding with the formula padding = ceil(kernel - stride) / 2, due to a wav2letter++ issue.
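The length arithmetic above can be checked with the standard Conv1d output-length formula. A small sketch (the helper function is mine; the (kernel, stride, padding) values are copied from the wav2letter.py diff in this PR) tracing the time dimension for the test's input_length of 800:

```python
def conv1d_out_len(length, kernel, stride, padding):
    # Standard Conv1d output-length formula (dilation = 1).
    return (length + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for every Conv1d in the PR's model:
# waveform front end first, then the acoustic model layers.
layers = [(250, 160, 45), (48, 2, 23)] + [(7, 1, 3)] * 7 \
    + [(32, 1, 15), (1, 1, 0), (1, 1, 0)]

length = 800  # same input_length as in the test above
for kernel, stride, padding in layers:
    length = conv1d_out_len(length, kernel, stride, padding)

print(length)  # → 1, matching the test's expected output shape (1, batch_size, 40)
```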

Contributor:

> I used padding same with formula padding = ceil(kernel-stride)/2 due to wav2letter++ issue

I feel like this should be mentioned in the documentation.

        batch_size = 2
        n_features = 1
        input_length = 800

        model = wav2letter()
        x = torch.rand(batch_size, n_features, input_length)
        out = model(x)

        assert out.size() == (1, batch_size, 40)


if __name__ == '__main__':
    unittest.main()
1 change: 1 addition & 0 deletions torchaudio/models/__init__.py
@@ -0,0 +1 @@
from .wav2letter import *
79 changes: 79 additions & 0 deletions torchaudio/models/wav2letter.py
@@ -0,0 +1,79 @@
import torch
from torch import nn

__all__ = ["Wav2Letter", "wav2letter"]


class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
    <https://arxiv.org/abs/1609.03193>`_ paper.

    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        version (str, optional): Input type that Wav2Letter expects: ``waveform``,
            ``power_spectrum`` or ``mfcc``. (Default: ``waveform``)
        n_input_features (int, optional): Number of extracted input features; must be
            specified when ``power_spectrum`` or ``mfcc`` is used. (Default: ``None``)
    """

    def __init__(self, num_classes=40, version="waveform", n_input_features=None):
        super(Wav2Letter, self).__init__()
        n_input_features = 250 if not n_input_features else n_input_features

        acoustic_model = nn.Sequential(
            nn.Conv1d(in_channels=n_input_features, out_channels=250, kernel_size=48, stride=2, padding=23),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=15),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )

        if version == "waveform":
            waveform_model = nn.Sequential(
                nn.Conv1d(in_channels=1, out_channels=250, kernel_size=250, stride=160, padding=45),
                nn.ReLU(inplace=True)
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)

        if version in ["power_spectrum", "mfcc"]:
            self.acoustic_model = acoustic_model
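As a sanity check on the 20 ms figure from the review discussion: the waveform front end has stride 160 and the first acoustic-model layer has stride 2, for a combined hop of 320 samples per output frame. A standalone snippet (not part of the PR) using the same two layers on one second of 16 kHz audio:

```python
import torch
from torch import nn

# Same parameters as the waveform front end and the first acoustic layer above.
front_end = nn.Conv1d(1, 250, kernel_size=250, stride=160, padding=45)
first_acoustic = nn.Conv1d(250, 250, kernel_size=48, stride=2, padding=23)

one_second = torch.rand(1, 1, 16000)  # one second at 16 kHz
frames = first_acoustic(front_end(one_second))
print(frames.shape[-1])  # → 50 frames per second, i.e. one frame every 20 ms
```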

    def forward(self, x):
        # type: (Tensor) -> Tensor
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, n_features, input_length).

        Returns:
            torch.Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
        """

        x = self.acoustic_model(x)
        x = nn.functional.log_softmax(x, dim=1)
        x = x.permute(2, 0, 1)
        return x

Contributor (on the ``Returns`` section):

If we were initializing the module lazily, we could infer the number of features. Maybe a factory function could help with that? I don't see it done with torchvision though, so I won't be advocating this for this PR. :)
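The (input_length, batch_size, number_of_classes) layout that forward returns matches what torch.nn.CTCLoss expects for its log_probs argument. A hypothetical training-step sketch (all shapes below are made up for illustration, not taken from this PR):

```python
import torch

# Stand-in for the model output: (input_length, batch_size, num_classes)
# log-probabilities, the layout torch.nn.CTCLoss expects.
T, N, C = 50, 2, 40  # output frames, batch size, classes (class 0 = blank)
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)

# Made-up targets: 10 labels per utterance, drawn from 1..C-1 (0 is blank).
targets = torch.randint(1, C, (N, 10), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

ctc_loss = torch.nn.CTCLoss(blank=0)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)  # scalar loss
```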


def wav2letter(**kwargs):
    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
    <https://arxiv.org/abs/1609.03193>`_ paper.
    """
    model = Wav2Letter(**kwargs)
    return model