
Hello, a question: there is no number-of-classes parameter in the arguments #1

Open
smallliang opened this issue Dec 29, 2018 · 13 comments


@smallliang

Where should it be added? I couldn't find it in the from_pretrained code.

@real-brilliant (Owner)

You don't need a number-of-classes parameter. In `MyPro().get_labels` in 'bert.py', just change `return [0, 1]` to the list of your class names.

@li-cheng12

li-cheng12 commented Jan 25, 2019

I ran into a similar problem. After changing return [0, 1] to my own label list, I got this error: RuntimeError: CUDA error: device-side assert triggered. Searching online suggests that every label id must lie between 0 and n_classes, but I can't find where n_classes (the number of classes) is set. I have 4000+ classes. Has anyone hit something similar?
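That CUDA assert fires when some target id falls outside [0, n_classes). A minimal sketch of a pre-flight check (the function name is hypothetical, not part of this repo) that finds the offending ids before the loss computation aborts:

```python
def check_label_ids(label_ids, n_classes):
    """Return the ids that would trip the
    `cur_target >= 0 && cur_target < n_classes` assertion."""
    return [t for t in label_ids if not 0 <= t < n_classes]

# 4000 classes, but one stray id of 4200 slipped in
bad = check_label_ids([0, 17, 4200], n_classes=4000)
print(bad)  # -> [4200]
```

Running this on the training label ids before the first forward pass gives a readable error instead of a device-side assert.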

@badbubble

@licheng-pro
If the labels in your data have already been converted to label ids:

def get_labels(self):
    # n_classes: the total number of classes in your data
    return [str(i) for i in range(n_classes)]

@li-cheng12

@ETCartman They have not been converted to label ids, and when I switch to CPU I get this error: RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:93

@badbubble

@licheng-pro The return needs to be label ids, not labels!
try this:

def get_labels(self, label_path):
    # build a label -> id mapping from a file with one label per line
    with open(label_path, 'r', encoding='utf-8') as fo:
        lines = fo.readlines()
    label_to_id = {}
    for i, line in enumerate(lines):
        label_to_id[line.strip()] = i
    print(label_to_id)
    return [str(i) for i in range(len(label_to_id))]

or just:

def get_labels(self, label_path):
    with open(label_path, 'r', encoding='utf-8') as fo:
        lines = fo.readlines()
    return [str(i) for i in range(len(lines))]

@li-cheng12

li-cheng12 commented Jan 25, 2019

@ETCartman
Isn't this spot in convert_examples_to_features the place where labels get turned into label ids? get_labels should return the labels themselves, right?
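For context, the mapping step being referred to is an index built over the label list. A minimal sketch (the helper name is hypothetical; `label_list` stands for whatever `get_labels()` returns):

```python
def build_label_map(label_list):
    # convert_examples_to_features builds an index like this,
    # so label ids are the positions 0..len(label_list)-1
    return {label: i for i, label in enumerate(label_list)}

label_map = build_label_map(["sports", "finance", "tech"])
print(label_map["finance"])  # -> 1
```

Since ids come from list positions, any duplicate or missing label in the list shifts every id after it.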

@badbubble

@licheng-pro I just checked the code and you're right :stuck_out_tongue:, though I classified 2k classes without any problem.
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:93
My feeling is this error must come from the labels, e.g. duplicates in the data at label_path. You could try
return list(set(label_list))
and see.
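One caveat with `list(set(...))`: set order is not deterministic across runs, which would shuffle the label ids between training and evaluation. A sketch that dedups while keeping first-seen order (the helper name is hypothetical):

```python
def unique_labels(label_list):
    # dict.fromkeys removes duplicates but, unlike set(),
    # preserves first-seen order, so label ids stay stable
    return list(dict.fromkeys(label_list))

print(unique_labels(["a", "b", "a", "c"]))  # -> ['a', 'b', 'c']
```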

@li-cheng12

I've solved it, thanks a lot @ETCartman

@liuyijiang1994

Now, however, this line:

model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
                                                              args.local_rank))

raises the following error:

File "/root/anaconda3/envs/liu37/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py", line 581, in from_pretrained
   model = cls(config, *inputs, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'num_labels'

@Zhaohaoran1997

Now, however, this line:

model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
                                                              args.local_rank))

raises the following error:

File "/root/anaconda3/envs/liu37/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py", line 581, in from_pretrained
   model = cls(config, *inputs, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'num_labels'

I'm hitting the same problem. Did you solve it?

@liuyijiang1994

@Zhaohaoran1997 Just add this parameter at the end:

model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
                                                              args.local_rank), num_labels=len(label_list))

You can also refer to my fork: https://github.com/liuyijiang1994/bert_senta

@Zhaohaoran1997

@Zhaohaoran1997 Just add this parameter at the end:

model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
                                                              args.local_rank), num_labels=len(label_list))

You can also refer to my fork: https://github.com/liuyijiang1994/bert_senta

After I changed num_labels the process gets killed. Is there a bug in my change, or is the machine just underpowered?

04/22/2019 15:58:31 - INFO - pytorch_pretrained_bert.modeling - Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.bias', 'classifier.weight']
04/22/2019 15:58:31 - INFO - pytorch_pretrained_bert.modeling - Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
04/22/2019 15:58:34 - INFO - __main__ - ***** Running training *****
04/22/2019 15:58:34 - INFO - __main__ - Num examples = 43425
04/22/2019 15:58:34 - INFO - __main__ - Batch size = 128
04/22/2019 15:58:34 - INFO - __main__ - Num steps = 3392
Epoch: 0%| | 0/10 [00:00<?, ?it/s]
Iteration: 0%| | 0/340 [00:00<?, ?it/s]
Killed

@liuyijiang1994

@Zhaohaoran1997 The run looks normal up to that point, so it's probably a resource problem.
