Deep Speech does not train #94

Closed
switiz opened this issue Jan 25, 2021 · 4 comments

@switiz
Contributor

switiz commented Jan 25, 2021

Title

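When I select the Deep Speech model and start training, the error below occurs and training does not proceed.
Since training runs normally after fix 3, the data flow probably needs to be checked.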

Actions taken

  1. After forking, running CMD (python main.py model=ds2 train=ds2_train train.dataset_path=C:\SpeechRecognitionDataset\Dataset\AI_hub)
    raised a module not found (kospeech) error, so I added the repository to sys.path.
  2. The dataset path did not resolve, so I set it as an absolute path.
  3. Converting the lengths to the CPU in model_forward and before pack_padded_sequence makes training run.
    -> but an error occurs at around step 370.
      elif self.architecture in ('deepspeech2', 'jasper'):
          cpu_input_lengths = input_lengths.to(torch.device('cpu'))
          output, output_lengths = model(inputs, input_lengths)
          loss = self.criterion(output.transpose(0, 1), targets, output_lengths, target_lengths)
def forward(self, inputs: Tensor, input_lengths: Tensor):
    total_length = inputs.size(0)

    inputs = F.relu(self.batch_norm(inputs.transpose(1, 2)))
    inputs = inputs.transpose(1, 2)

    # pack_padded_sequence requires the lengths as a CPU int64 tensor
    cpu_input = input_lengths.to(torch.device('cpu'))
    output = nn.utils.rnn.pack_padded_sequence(inputs, cpu_input)
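The constraint behind the first traceback can be shown in isolation. In the PyTorch version shown in the log (1.7.1), `pack_padded_sequence` only accepts `lengths` as a 1D CPU int64 tensor, even when the padded inputs are on the GPU. A minimal sketch, assuming time-major inputs `(T, B, D)` as in the `forward` above; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

# Use the GPU when present, to mirror the training setup; the example also runs on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Time-major padded batch: T=4 steps, B=3 sequences, D=2 features.
inputs = torch.randn(4, 3, 2, device=device)
input_lengths = torch.tensor([4, 2, 1], device=device)  # int64, possibly on the GPU

# In PyTorch 1.7.1, passing GPU lengths here raises the RuntimeError from the log,
# so move only the lengths to the CPU; the inputs may stay on their device.
packed = nn.utils.rnn.pack_padded_sequence(inputs, input_lengths.cpu())

# batch_sizes[t] counts the sequences still active at time step t.
print(packed.batch_sizes.tolist())  # [3, 2, 1, 1]
```

The default `enforce_sorted=True` also requires the lengths to be in decreasing order, which holds here.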

Description

error

C:\Users\sanma\Documents\GitHub\KoSpeech\bin>python main.py model=ds2 train=ds2_train train.dataset_path=C:\SpeechRecognitionDataset\Dataset\AI_hub

C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\hydra\plugins\config_source.py:190: UserWarning:
Missing @package directive audio/fbank.yaml in file://C:\Users\sanma\Documents\GitHub\KoSpeech\configs.
See https://hydra.cc/docs/next/upgrades/0.11_to_1.0/adding_a_package_directive
warnings.warn(message=msg, category=UserWarning)
C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\hydra\plugins\config_source.py:190: UserWarning:
Missing @package directive model/ds2.yaml in file://C:\Users\sanma\Documents\GitHub\KoSpeech\configs.
See https://hydra.cc/docs/next/upgrades/0.11_to_1.0/adding_a_package_directive
warnings.warn(message=msg, category=UserWarning)
C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\hydra\plugins\config_source.py:190: UserWarning:
Missing @package directive train/ds2_train.yaml in file://C:\Users\sanma\Documents\GitHub\KoSpeech\configs.
See https://hydra.cc/docs/next/upgrades/0.11_to_1.0/adding_a_package_directive
warnings.warn(message=msg, category=UserWarning)
[2021-01-25 23:17:18,098][kospeech.utils][INFO] - audio:
audio_extension: pcm
sample_rate: 16000
frame_length: 20
frame_shift: 10
normalize: true
del_silence: true
feature_extract_by: kaldi
time_mask_num: 4
freq_mask_num: 2
spec_augment: true
input_reverse: false
transform_method: fbank
n_mels: 80
freq_mask_para: 18
audio_extension: pcm
transform_method: fbank
sample_rate: 16000
frame_length: 20
frame_shift: 10
n_mels: 80
normalize: true
del_silence: true
feature_extract_by: kaldi
freq_mask_para: 18
time_mask_num: 4
freq_mask_num: 2
spec_augment: true
input_reverse: false
model:
architecture: deepspeech2
teacher_forcing_ratio: 1.0
teacher_forcing_step: 0.01
min_teacher_forcing_ratio: 0.9
dropout: 0.3
bidirectional: false
joint_ctc_attention: false
max_len: 400
use_bidirectional: true
rnn_type: gru
hidden_dim: 1024
activation: hardtanh
num_encoder_layers: 3
architecture: deepspeech2
use_bidirectional: true
hidden_dim: 1024
dropout: 0.3
num_encoder_layers: 3
rnn_type: gru
max_len: 400
activation: hardtanh
teacher_forcing_ratio: 1.0
teacher_forcing_step: 0.0
min_teacher_forcing_ratio: 1.0
joint_ctc_attention: false
train:
dataset: kspon
dataset_path: C:\SpeechRecognitionDataset\Dataset\AI_hub
transcripts_path: C:/Users/sanma/Documents/GitHub/KoSpeech/data/transcripts.txt
output_unit: character
batch_size: 32
save_result_every: 1000
checkpoint_every: 5000
print_every: 10
mode: train
num_workers: 4
use_cuda: true
init_lr_scale: 0.01
final_lr_scale: 0.05
max_grad_norm: 400
weight_decay: 1.0e-05
seed: 777
resume: false
optimizer: adam
init_lr: 1.0e-06
final_lr: 1.0e-06
peak_lr: 0.0001
warmup_steps: 1000
num_epochs: 70
reduction: mean
dataset: kspon
dataset_path: ''
transcripts_path: ../../../data/transcripts.txt
output_unit: character
num_epochs: 70
batch_size: 32
save_result_every: 1000
checkpoint_every: 5000
print_every: 10
mode: train
seed: 777
resume: false
num_workers: 4
use_cuda: true
optimizer: adam
init_lr: 1.0e-06
final_lr: 1.0e-06
peak_lr: 0.0001
init_lr_scale: 0.01
final_lr_scale: 0.05
max_grad_norm: 400
warmup_steps: 400
weight_decay: 1.0e-05
reduction: mean

[2021-01-25 23:17:18,136][kospeech.utils][INFO] - Operating System : Windows 10
[2021-01-25 23:17:18,136][kospeech.utils][INFO] - Processor : AMD64 Family 23 Model 113 Stepping 0, AuthenticAMD
[2021-01-25 23:17:18,137][kospeech.utils][INFO] - device : GeForce GTX 1070
[2021-01-25 23:17:18,138][kospeech.utils][INFO] - CUDA is available : True
[2021-01-25 23:17:18,138][kospeech.utils][INFO] - CUDA version : 10.2
[2021-01-25 23:17:18,138][kospeech.utils][INFO] - PyTorch version : 1.7.1
C:\Users\sanma\Documents\GitHub\KoSpeech\data\vocab\aihub_character_vocabs.csv
[2021-01-25 23:17:18,141][kospeech.utils][INFO] - split dataset start !!
[2021-01-25 23:17:21,310][kospeech.utils][INFO] - Applying Spec Augmentation...
[2021-01-25 23:17:22,020][kospeech.utils][INFO] - Applying Spec Augmentation...
[2021-01-25 23:17:22,865][kospeech.utils][INFO] - Applying Spec Augmentation...
[2021-01-25 23:17:23,716][kospeech.utils][INFO] - Applying Spec Augmentation...
[2021-01-25 23:17:24,438][kospeech.utils][INFO] - split dataset complete !!
[2021-01-25 23:17:26,611][kospeech.utils][INFO] - start
[2021-01-25 23:17:26,611][kospeech.utils][INFO] - Epoch 0 start
Traceback (most recent call last):
File "main.py", line 170, in main
last_model_checkpoint = train(config)
File "main.py", line 138, in train
model = trainer.train(
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\trainer\supervised_trainer.py", line 155, in train
model, train_loss, train_cer = self.__train_epoches(
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\trainer\supervised_trainer.py", line 242, in __train_epoches
output, loss, ctc_loss, cross_entropy_loss = self.model_forward(
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\trainer\supervised_trainer.py", line 410, in model_forward
output, output_lengths = model(inputs, input_lengths)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\parallel\data_parallel.py", line 159, in forward
return self.module(*inputs[0], **kwargs[0])
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\models\deepspeech2\model.py", line 100, in forward
output = rnn_layer(output, output_lengths)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\models\modules.py", line 103, in forward
output = nn.utils.rnn.pack_padded_sequence(inputs, input_lengths)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\utils\rnn.py", line 244, in pack_padded_sequence
_VF._pack_padded_sequence(input, lengths, batch_first)
RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Error at around step 370
[2021-01-26 00:15:16,965][kospeech.utils][INFO] - step: 380/77500, loss: 4.628228, cer: 1.94, elapsed: 19.49s 13.26m 0.22h, lr: 0.000038
Traceback (most recent call last):
File "main.py", line 170, in main
last_model_checkpoint = train(config)
File "main.py", line 138, in train
model = trainer.train(
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\trainer\supervised_trainer.py", line 155, in train
model, train_loss, train_cer = self.__train_epoches(
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\trainer\supervised_trainer.py", line 242, in __train_epoches
output, loss, ctc_loss, cross_entropy_loss = self.model_forward(
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\trainer\supervised_trainer.py", line 411, in model_forward
output, output_lengths = model(inputs, input_lengths)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\parallel\data_parallel.py", line 159, in forward
return self.module(*inputs[0], **kwargs[0])
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\models\deepspeech2\model.py", line 100, in forward
output = rnn_layer(output, output_lengths)
File "C:\ProgramData\Anaconda3\envs\xray_pytorch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\sanma\Documents\GitHub\KoSpeech\kospeech\models\modules.py", line 103, in forward
cpu_input = input_lengths.to(torch.device('cpu'))
RuntimeError: CUDA error: unspecified launch failure
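CUDA kernels run asynchronously, so an `unspecified launch failure` is usually raised at the next synchronizing call (here the `.to(torch.device('cpu'))` copy), not at the kernel that actually failed. A common way to get a more accurate traceback is to set `CUDA_LAUNCH_BLOCKING=1` before CUDA is initialized. A minimal sketch; this slows training down and is meant only for debugging.

```python
import os

# Must be set before CUDA is initialized; setting it before importing torch is safest.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (imported after the environment variable on purpose)

# With blocking launches, a failing kernel raises at the line that launched it.
print("CUDA available:", torch.cuda.is_available())
```

On Windows the same variable can instead be set in the shell with `set CUDA_LAUNCH_BLOCKING=1` before running `python main.py ...`.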

Linked Issues

  • resolved #
@sooftware
Owner

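There seems to have been a minor bug in the DeepSpeech2 implementation.
I've fixed the part I suspected. If training still doesn't work after this, please leave a comment.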

@switiz
Contributor Author

switiz commented Jan 26, 2021

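Even after syncing to the latest version the error still occurred, so I searched around, and it appears to be a PyTorch issue.
The following issues have been filed, including one on the PyTorch GitHub: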

pytorch/pytorch#43227
jdb78/pytorch-forecasting#135
https://forum.pyro.ai/t/pytorch-rnn-lengths-in-cpu-incompatibility-with-elbo-loss-calculation/2391/5
The reported workaround is to go through one CPU/GPU conversion: pass input_length on the CPU, then move out_length back to the GPU.

Changes

I have changed only input_lengths and training is running now; if the step-370 error comes back, I'll comment again.

modules.py

def forward(self, inputs: Tensor, input_lengths: Tensor):
    device = inputs.device
    total_length = inputs.size(0)

    inputs = F.relu(self.batch_norm(inputs.transpose(1, 2)))
    inputs = inputs.transpose(1, 2)

    # The lengths must be on the CPU; the padded inputs can stay on the GPU
    output = nn.utils.rnn.pack_padded_sequence(inputs, input_lengths.cpu())
    output, hidden = self.rnn(output)
    output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=total_length)
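For reference, a self-contained sketch of a batch-norm + ReLU + RNN block with this fix applied. It mirrors the `forward` above, but the class name `BNReluRNNSketch`, its constructor arguments, `enforce_sorted=False`, and returning the lengths moved back to the input's device are assumptions for illustration, not KoSpeech's actual code. The last step follows the reported workaround of giving the packing call CPU lengths and returning lengths on the GPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class BNReluRNNSketch(nn.Module):
    """Hypothetical time-major (T, B, D) block: BatchNorm -> ReLU -> packed GRU."""

    def __init__(self, input_size: int, hidden_dim: int, bidirectional: bool = True):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_size)
        self.rnn = nn.GRU(input_size, hidden_dim, bidirectional=bidirectional)

    def forward(self, inputs: Tensor, input_lengths: Tensor):
        device = inputs.device
        total_length = inputs.size(0)

        # BatchNorm1d normalizes dim 1, so move the features there and back.
        inputs = F.relu(self.batch_norm(inputs.transpose(1, 2)))
        inputs = inputs.transpose(1, 2)

        # Packing needs CPU int64 lengths; enforce_sorted=False removes the
        # requirement that the batch be sorted by decreasing length.
        cpu_lengths = input_lengths.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(inputs, cpu_lengths, enforce_sorted=False)
        output, _ = self.rnn(packed)

        # Restores the original batch order and pads back to total_length with zeros.
        output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=total_length)

        # Hand the lengths back on the same device as the inputs.
        return output, cpu_lengths.to(device)
```

For example, `BNReluRNNSketch(80, 16)` applied to a `(50, 4, 80)` batch with lengths `[50, 30, 45, 10]` returns a `(50, 4, 32)` output, since the bidirectional GRU doubles the feature size.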

@sooftware
Owner

Yes, thank you for the bug report!
I'd appreciate it if you could share the results afterwards as well.

If more code changes are needed, please feel free to open a PR.

sooftware added a commit that referenced this issue Jan 27, 2021
@switiz
Contributor Author

switiz commented Jan 27, 2021

I've opened a PR with the changes.

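Even with this fix applied, training still stalls after a certain number of steps. Looking into it, this appears to be a problem between the GPU driver and Windows: in general, when the GPU is underpowered relative to the model's compute load, it can fail to respond to the OS in time.
For troubleshooting, I referred to the tracker below.

Issue: pytorch/pytorch#27837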

  1. Changed the TDR settings
    -> the issue still occurs
  2. torch.backends.cudnn.enabled = False
    -> training works, but performance drops by 30%
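Workaround 2 is a single global switch. A minimal sketch that makes it opt-in, so the slowdown is only paid on machines affected by the stall; the environment variable `KOSPEECH_DISABLE_CUDNN` and the helper `configure_cudnn` are hypothetical names, not KoSpeech or PyTorch settings.

```python
import os

import torch


def configure_cudnn(disable: bool) -> None:
    """Globally turn cuDNN off (True) or back on (False) for later CUDA ops."""
    # With cuDNN off, RNN and convolution layers fall back to PyTorch's own
    # CUDA kernels, which can make training noticeably slower.
    torch.backends.cudnn.enabled = not disable


# Hypothetical opt-in switch: set KOSPEECH_DISABLE_CUDNN=1 only where needed.
configure_cudnn(os.environ.get("KOSPEECH_DISABLE_CUDNN") == "1")
print("cuDNN enabled:", torch.backends.cudnn.enabled)
```

Keeping it behind a flag leaves the faster cuDNN path as the default on hardware that does not hit the driver timeout.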

The performance drop is severe, and the problem seems to depend on the driver or hardware, so I'm only reporting it here.

Thank you.

@switiz switiz closed this as completed Jan 27, 2021