
BART long-text training issue #14

Open
YoungChanYY opened this issue Apr 11, 2023 · 5 comments
Labels
question Further information is requested wontfix This will not be worked on

Comments

YoungChanYY commented Apr 11, 2023

I'm training with the BART code, where each training example has an input of about 1,000 characters and a target of 30,000-50,000 characters. After a few epochs, training fails with the error shown below.
If I cap both input and output at around 100 characters, training runs normally with no errors.

My question: does the BART model impose any limits on input/output length? This looks like an internal embedding-dimension error. Thanks.

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)
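For context, BART's learned position embeddings cap sequences at 1,024 tokens (the `max_position_embeddings` value for facebook/bart-base and facebook/bart-large), so a 30,000-50,000-character target tokenizes far past the embedding table and position indices go out of range. A minimal sketch of the kind of length guard that avoids this (the function name is hypothetical, not part of textgen):

```python
# Hypothetical guard: keep token-id sequences within BART's position-embedding
# table (1024 entries for facebook/bart-base and facebook/bart-large).
MAX_POSITIONS = 1024

def clamp_to_model_limit(token_ids, max_positions=MAX_POSITIONS):
    """Truncate a token-id list so every position index stays < max_positions."""
    return token_ids[:max_positions] if len(token_ids) > max_positions else token_ids

long_target = list(range(50_000))              # stand-in for a very long tokenized target
print(len(clamp_to_model_limit(long_target)))  # 1024
```

Passing `truncation=True` with an appropriate `max_length` to the tokenizer achieves the same thing at encoding time.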

YoungChanYY added the question label Apr 11, 2023
@YoungChanYY (Author)

The error seems to occur in predict. When I disable eval during training, training runs fine.

Traceback (most recent call last):
  File "train_bart_text2abc.py", line 180, in <module>
    main()
  File "train_bart_text2abc.py", line 163, in main
    model.train_model(train_df, eval_data=eval_df, split_on_space=True, matches=count_matches)
  File "textgen/seq2seq/bart_seq2seq_model.py", line 452, in train_model
    **kwargs,
  File "textgen/seq2seq/bart_seq2seq_model.py", line 983, in train
    **kwargs,
  File "textgen/seq2seq/bart_seq2seq_model.py", line 1153, in eval_model
    preds = self.predict(to_predict, split_on_space=split_on_space)
  File "textgen/seq2seq/bart_seq2seq_model.py", line 1310, in predict
    num_return_sequences=self.args.num_return_sequences,
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/generation/utils.py", line 1400, in generate
    **model_kwargs,
  File "/usr/local/lib/python3.7/dist-packages/transformers/generation/utils.py", line 2183, in greedy_search
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1389, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1268, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1124, in forward
    use_cache=use_cache,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 431, in forward
    output_attentions=output_attentions,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 275, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)
../aten/src/ATen/native/cuda/Indexing.cu:650: indexSelectSmallIndex: block: [0,0,0], thread: [0,0,0] Assertion srcIndex < srcSelectDimSize failed.
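The last line is the real cause: `indexSelectSmallIndex` asserts when an index fed to an embedding/`index_select` lookup is >= the table size, e.g. a position id beyond what the model's position-embedding table covers. CUDA errors surface asynchronously, so the cuBLAS failure in `torch.bmm` is just where the poisoned device state happened to be noticed. A pure-Python sketch of the invariant the kernel enforces (the function is an illustration, not the actual kernel):

```python
def index_select(table, indices):
    """Mimic the lookup behind the CUDA assertion: every index must satisfy
    0 <= idx < len(table), i.e. the `srcIndex < srcSelectDimSize` check."""
    out = []
    for idx in indices:
        if not 0 <= idx < len(table):
            raise IndexError(f"index {idx} out of range for table of size {len(table)}")
        out.append(table[idx])
    return out

positions = list(range(1030))   # position ids 0..1029, 6 past the table
table = [0.0] * 1024            # a 1024-entry position-embedding table
try:
    index_select(table, positions)
except IndexError as e:
    print(e)                    # index 1024 out of range for table of size 1024
```

Running with `CUDA_LAUNCH_BLOCKING=1` (or on CPU) typically makes the true out-of-range lookup surface at its real call site instead of a later cuBLAS call.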

@shibing624 (Owner)

Let me take a look at the evaluate logic.

@YoungChanYY (Author) commented Apr 12, 2023

Thanks.

I also spotted another place that looks wrong: in the preprocess_data_bart(data) function in textgen/seq2seq/bart_seq2seq_utils.py, the handling of target_ids has the problem and suggested fix shown below. Could you check whether this is right? Thanks!

def preprocess_data_bart(data):
    input_text, target_text, tokenizer, args = data
    ......
    target_ids = tokenizer.batch_encode_plus(
        [target_text],
        # max_length=args.max_seq_length,  # original code
        max_length=args.max_length,  # suggested change
        padding="max_length",
        return_tensors="pt",
        truncation=True,
    )
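The distinction matters because these seq2seq args keep two separate caps: max_seq_length for the encoder input and max_length for the decoder/target side, so truncating targets with the (typically smaller) encoder cap silently shortens the labels. A toy illustration of the effect (the attribute names mirror the snippet above; the numeric values are made up):

```python
from dataclasses import dataclass

@dataclass
class Args:
    max_seq_length: int = 128   # encoder-side cap (inputs)
    max_length: int = 1024      # decoder-side cap (targets)

def truncate(ids, cap):
    return ids[:cap]

args = Args()
target = list(range(2000))                     # a long tokenized target
buggy = truncate(target, args.max_seq_length)  # original code: labels cut to 128
fixed = truncate(target, args.max_length)      # suggested change: labels keep 1024
print(len(buggy), len(fixed))                  # 128 1024
```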

@shibing624 (Owner)

Yes, that's right. Fixed: 7a0be59

stale bot commented Dec 27, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix This will not be worked on label Dec 27, 2023