
Conformer training with warp_rnnt #119

Closed
mjurkus opened this issue Mar 9, 2021 · 5 comments
mjurkus commented Mar 9, 2021

I was wondering if you could comment on why you're dropping the 1st target here when training the Conformer with RNN-T?

I'm asking because I got an error when doing the same with https://github.com/sooftware/conformer:

    def forward(self, inputs, targets, input_lengths, target_lengths):
        return self.conformer(inputs, input_lengths, targets, target_lengths)

    def training_step(self, batch, batch_idx):
        inputs, targets, input_lengths, target_lengths = batch
        outputs = self.forward(inputs, targets, input_lengths, target_lengths)
        # drop the 1st target token before passing the labels to the loss
        loss = self.criterion(outputs, targets[:, 1:].contiguous().int(), input_lengths.int(), target_lengths.int())

        self.log("loss", loss)

        return loss

where criterion is warp_rnnt.rnnt_loss.

The error is

RuntimeError: The expanded size of the tensor (7) must match the existing size (8) at non-singleton dimension 2.  Target sizes: [4, 47, 7].  Tensor sizes: [4, 1, 8]

which makes sense when you look at the shapes:

input shape torch.Size([4, 191, 80])
outputs shape torch.Size([4, 47, 8, 5000])
targets shape torch.Size([4, 9])
input_lengths shape torch.Size([4])
target_lengths shape torch.Size([4])
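The mismatch follows from how the transducer loss pairs the two tensors: the joint output has shape (N, T, U+1, V), so the labels passed to the loss must have length U = outputs.size(2) - 1 = 7 per utterance, while targets[:, 1:] still has length 8. A stdlib-only sanity check over the shapes above (the reading of the 9 target columns as sos + 7 labels + eos is my assumption, not something the repo states):

```python
# Shapes reported above: (batch, frames, decoder steps, vocab) and (batch, tokens)
outputs_shape = (4, 47, 8, 5000)
targets_shape = (4, 9)   # assumed layout: sos + 7 labels + eos

n, t, u_plus_1, v = outputs_shape
expected_label_len = u_plus_1 - 1        # the loss wants U = 7 labels
given_label_len = targets_shape[1] - 1   # targets[:, 1:] supplies 8

assert expected_label_len == 7 and given_label_len == 8  # hence the size error
# Dropping both special tokens instead, targets[:, 1:-1], would yield
# length 7 and make the shapes consistent.
```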

mjurkus commented Mar 13, 2021

My configuration looks like this:

audio:
  audio_extension: wav
  sample_rate: 16000
  frame_length: 20
  frame_shift: 10
  normalize: true
  del_silence: false
  time_mask_num: 4
  freq_mask_num: 2
  spec_augment: true
  input_reverse: false
  transform_method: fbank
  n_mels: 80
  freq_mask_para: 18
model:
  architecture: conformer
  dropout: 0.3
  bidirectional: false
  max_len: 400
  feed_forward_expansion_factor: 4
  conv_expansion_factor: 2
  input_dropout_p: 0.1
  feed_forward_dropout_p: 0.1
  attention_dropout_p: 0.1
  conv_dropout_p: 0.1
  decoder_dropout_p: 0.1
  conv_kernel_size: 31
  half_step_residual: true
  num_decoder_layers: 1
  decoder_rnn_type: lstm
  encoder_dim: 144
  decoder_dim: 320
  num_encoder_layers: 16
  num_attention_heads: 4

sooftware (Owner) commented

Sorry for the late response. I'm very busy these days. The first drop means dropping the sos token. It looks like a code fix is needed, but I've been busy lately and haven't been able to get to it.
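In other words, the decoder input keeps the sos prefix while the labels scored by the loss should carry neither sos nor eos. A minimal, list-based sketch of that convention (the token ids and helper name are illustrative, not the repo's API):

```python
SOS, EOS = 1, 2  # illustrative special-token ids

def split_targets(seq):
    """Separate the decoder input (with sos) from the loss labels (without sos/eos)."""
    decoder_input = seq[:-1]   # keep sos, drop eos
    loss_labels = seq[1:-1]    # drop both sos and eos
    return decoder_input, loss_labels

dec_in, labels = split_targets([SOS, 5, 6, 7, EOS])
# dec_in keeps the sos prefix; labels holds only the 3 real tokens
```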


mjurkus commented Mar 14, 2021

What should be fixed? I could look into it.

sooftware (Owner) commented

Sorry for the very late response. How about training with CTC loss (not the transducer)? I think there's a bug right now, but I'm busy and can't afford to fix it.


mjurkus commented Mar 29, 2021

I've looked into training with CTC, but something was wrong there as well - it also crashed, though now I can no longer reproduce the error. I'll try to reproduce it.
