Add model Wav2Letter #462

Merged
merged 15 commits into from Apr 28, 2020

Conversation

tomassosorio
Contributor

@tomassosorio tomassosorio commented Mar 11, 2020

  • Add documentation for the models module
  • Add documentation for the Wav2Letter model
  • Add a unit test for Wav2Letter
  • Add the Wav2Letter model, following the wav2letter paper

Relates to #446



class ModelTester(unittest.TestCase):
def test_wav2letter(self):
Contributor

Should we extend these tests further (i.e. specific input / output tests for fixed weights) or is there little value in it?

Contributor Author

From what I saw in torchvision, they had problems with extensive tests slowing down Travis.

The test I added checks that the model outputs one letter every 20 ms. Since the sample rate they use is 16000 Hz, 20 ms corresponds to 320 samples. However, because of the padding at the 10th layer, the model produces 1 extra letter as output.

I used "same" padding with the formula padding = ceil((kernel - stride) / 2), because of a wav2letter++ issue.
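The arithmetic in this comment can be sketched in plain Python. This is only an illustration: the helper names are made up here, the formula is read as ceil((kernel - stride) / 2), and the kernel sizes 7 and 32 are assumed values chosen to show the effect, not the layer sizes of the PR. The output-length expression is the standard one for a 1-D convolution with dilation 1.

```python
import math


def samples_per_frame(sample_rate_hz: int, frame_ms: int) -> int:
    """Number of audio samples covered by one frame of `frame_ms` milliseconds."""
    return sample_rate_hz * frame_ms // 1000


def same_padding(kernel_size: int, stride: int) -> int:
    """Padding formula from the comment: ceil((kernel - stride) / 2)."""
    return math.ceil((kernel_size - stride) / 2)


def conv1d_output_length(length: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


# 20 ms of audio at 16 kHz.
frame = samples_per_frame(16000, 20)

# Odd kernel, stride 1: the padding is exact, so the length is preserved.
odd_pad = same_padding(7, 1)
odd_out = conv1d_output_length(100, 7, 1, odd_pad)

# Even kernel, stride 1 (assumed size): ceil rounds up, which adds one output step.
even_pad = same_padding(32, 1)
even_out = conv1d_output_length(100, 32, 1, even_pad)

print(frame, odd_pad, odd_out, even_pad, even_out)  # 320 3 100 16 101
```

Under these assumptions, a layer with an even kernel and stride 1 is enough to explain one extra output step, which matches the behaviour described for the 10th layer.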

Contributor

I used "same" padding with the formula padding = ceil((kernel - stride) / 2), because of a wav2letter++ issue.

I feel like this should be mentioned in the documentation.

@cpuhrsch
Contributor

As a generic reminder: since this creates a whole new top-level folder and design (a collection of models), I think it's worth pausing to make sure we're not setting ourselves up for some hard-to-reverse decisions down the road.

Contributor Author

@tomassosorio tomassosorio left a comment

As a generic reminder: since this creates a whole new top-level folder and design (a collection of models), I think it's worth pausing to make sure we're not setting ourselves up for some hard-to-reverse decisions down the road.

I totally agree. I suggested this in issue #446, and since there were no blockers I went ahead with a PR, but we can discuss further whether it would be beneficial and, if so, how it should be implemented.

@vincentqb
Contributor

As a generic reminder: since this creates a whole new top-level folder and design (a collection of models), I think it's worth pausing to make sure we're not setting ourselves up for some hard-to-reverse decisions down the road.

Yes, we need to make sure we think the interface through correctly. This provides a nice forcing function for us to do so. :)

Contributor

@vincentqb vincentqb left a comment

Overall, this looks good to me. Do we have other models we could consider implementing? This would help us think about whether the interface is general enough while still being simple.

One open point is to ensure correctness when someone suggests a new model. The tests added here are very limited, and testing that the implementation is right is not obvious.

x (torch.Tensor): Tensor of dimension (batch_size, n_features, input_length).

Returns:
torch.Tensor: Predictor tensor of dimension (input_length, batch_size, number_of_classes).
Contributor

If we were initializing the module lazily, we could infer the number of features. Maybe a factory function could help with that? I don't see it done with torchvision though, so I won't be advocating this for this PR. :)

@vincentqb
Contributor

vincentqb commented Apr 2, 2020

One open point is to ensure correctness when someone suggests a new model. The tests added here are very limited, and testing that the implementation is right is not obvious.

Is there an invariant that could be verified? Say, that the shape remains the same at each layer?

We could also think of an optional test that runs a simple convergence test on a standard dataset. This could get expensive though.
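One cheap way to check such an invariant, without running the model at all, is to trace the time dimension through a list of layer parameters. The sketch below is a hypothetical helper, not part of torchaudio; the `(kernel_size, stride, padding)` tuples are assumed values used only for illustration.

```python
from typing import List, Tuple

# One (kernel_size, stride, padding) tuple per layer; values are illustrative only.
LayerSpec = Tuple[int, int, int]


def trace_lengths(input_length: int, layers: List[LayerSpec]) -> List[int]:
    """Return the time length after each layer (1-D convolution, dilation 1)."""
    lengths = []
    length = input_length
    for kernel_size, stride, padding in layers:
        length = (length + 2 * padding - kernel_size) // stride + 1
        lengths.append(length)
    return lengths


def layers_breaking_invariant(input_length: int, layers: List[LayerSpec]) -> List[int]:
    """Indices of layers whose output length differs from their input length."""
    broken = []
    previous = input_length
    for index, length in enumerate(trace_lengths(input_length, layers)):
        if length != previous:
            broken.append(index)
        previous = length
    return broken


# Assumed stack: three length-preserving layers and one even-kernel layer.
stack = [(7, 1, 3), (7, 1, 3), (32, 1, 16), (1, 1, 0)]
lengths = trace_lengths(50, stack)
broken = layers_breaking_invariant(50, stack)
print(lengths, broken)  # [50, 50, 51, 51] [2]
```

A check like this only covers shapes. It says nothing about whether the weights or the training behaviour are right, so it complements a convergence test rather than replacing it.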

@vincentqb
Contributor

The test is failing:

_________________________ ModelTester.test_wav2letter __________________________
self = <test_models.ModelTester testMethod=test_wav2letter>
    def test_wav2letter(self):
        batch_size = 2
        n_features = 1
        input_length = 320
    
        model = Wav2Letter()
        x = torch.rand(batch_size, n_features, input_length)
>       out = model(x)
test/test_models.py:15: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/module.py:532: in __call__
    result = self.forward(*input, **kwargs)
torchaudio/models/wav2letter.py:71: in forward
    x = self.acoustic_model(x)
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/module.py:532: in __call__
    result = self.forward(*input, **kwargs)
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/container.py:100: in forward
    input = module(input)
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/module.py:532: in __call__
    result = self.forward(*input, **kwargs)
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/container.py:100: in forward
    input = module(input)
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/module.py:532: in __call__
    result = self.forward(*input, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = Conv1d(1, 250, kernel_size=(48,), stride=(2,), padding=(23,))
input = tensor([[[0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [....0000],
         [0.0000, 0.0973],
         [0.3981, 0.2518],
         [0.4468, 0.1599]]], grad_fn=<AsStridedBackward>)
    def forward(self, input):
        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
                            self.weight, self.bias, self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, self.weight, self.bias, self.stride,
>                       self.padding, self.dilation, self.groups)
E       RuntimeError: Given groups=1, weight of size 250 1 48, expected input[2, 250, 2] to have 1 channels, but got 250 channels instead
../../../miniconda3/envs/testenv/lib/python3.7/site-packages/torch/nn/modules/conv.py:202: RuntimeError
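The error says that a convolution built for 1 input channel received a tensor with 250 channels. The traceback does not show why the layers ended up chained that way, so the following is only a sketch of how this kind of mismatch can be detected from layer parameters alone. The helper name and the two example stacks are hypothetical.

```python
from typing import List, Optional, Tuple

# One (in_channels, out_channels) pair per convolution; values are illustrative only.
ChannelSpec = Tuple[int, int]


def first_channel_mismatch(input_channels: int, layers: List[ChannelSpec]) -> Optional[int]:
    """Return the index of the first layer whose in_channels does not match
    the channels produced by the previous layer, or None if all layers chain."""
    current = input_channels
    for index, (in_channels, out_channels) in enumerate(layers):
        if in_channels != current:
            return index
        current = out_channels
    return None


# Assumed broken stack: a second Conv1d(1, 250) follows a layer that outputs 250 channels.
bad_stack = [(1, 250), (1, 250)]
# Assumed fixed stack: the second layer accepts 250 channels.
good_stack = [(1, 250), (250, 250)]

print(first_channel_mismatch(1, bad_stack))   # 1
print(first_channel_mismatch(1, good_stack))  # None
```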

@vincentqb
Contributor

Overall, this looks very good to me, and serves as a great template for other models to come :)

@tomassosorio
Contributor Author

Overall, this looks very good to me, and serves as a great template for other models to come :)

Thanks! I will try to take a look today! Did not have the time yet :/

@tomassosorio
Contributor Author

tomassosorio commented Apr 27, 2020

@vincentqb Sorry for the delay, I was a bit busy.
If you could take a look, I would appreciate it :)

I used pytest instead of unittest, since from what I saw torchaudio seems to be moving towards pytest.

I also renamed the parameter version to input_type, since the same model can have different versions. However, might wave_type or wave_shape be a better name than input_type?

Contributor

@vincentqb vincentqb left a comment

LGTM. Thanks for setting this up, and looking forward to adding more models :)

We haven't yet made a final decision about pytest. Since the tests here are easy to convert either way, I will merge anyway.

@vincentqb vincentqb merged commit d678357 into pytorch:master Apr 28, 2020
@tomassosorio tomassosorio deleted the addModelWav2Letter branch April 28, 2020 21:34
3 participants