When calling test(testSet,"Test",e) in CNN/train.py, an error was raised and the test data could not be evaluated. I modified the train_RE function in CNN/src/trainer.py to print the predictions, and found that all the pred values are negative:

pred
tensor([[-15.4812],
        [-15.5336],
        [-15.7833],
        [-15.1775],
        [-15.8235],
        [-15.9747],
        [-15.4162],
        [-15.6784],
        [-15.5847],
        [-15.0997],
        [-15.4924], ...

These predictions look wrong to me. Also, when computing the loss, why does loss = -torch.sum(pred.view(lEn.size(0))) negate the predictions?

The modified code is as follows:

def train_RE(self, wordsEn, pos1En, pos2En, rEn, lEn, wordsZh, pos1Zh, pos2Zh, rZh, lZh, re_mask):
    self.model.train()
    pred = self.model(wordsEn, pos1En, pos2En, rEn, lEn, wordsZh, pos1Zh, pos2Zh, rZh, lZh, re_mask)
    print(pred, re_mask)
    loss = -torch.sum(pred.view(lEn.size(0))) \
        + Orth_Coef * self.model.Orth_con(wordsEn, pos1En, pos2En, wordsZh, pos1Zh, pos2Zh)
    if (loss != loss).data.any():  # NaN check: NaN != NaN
        print("NaN Loss (training RE)")
        exit()
    self.optim.zero_grad()
    loss.backward()
    self.optim.step()
    return loss.item()
The negative values are expected. As you can see in models.py, the model applies log_softmax to its output, and the log of a softmax probability is necessarily negative (softmax probabilities lie in (0, 1]). As for the loss, see Equation (13) in our paper: the negative sign is part of the formula. This is in fact the standard cross-entropy loss; if that is unfamiliar, please consult the relevant references.
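To make the point above concrete, here is a minimal pure-Python sketch (not the repo's code, and using plain floats rather than torch tensors) of why log_softmax outputs are always non-positive and why the loss carries a minus sign:

```python
import math

def log_softmax(scores):
    # log(softmax(x_i)) = x_i - log(sum_j exp(x_j))
    log_z = math.log(sum(math.exp(s) for s in scores))
    return [s - log_z for s in scores]

logits = [2.0, 0.5, -1.0]
log_probs = log_softmax(logits)

# Every softmax probability lies in (0, 1], so its log is <= 0:
assert all(lp <= 0 for lp in log_probs)

# Negative log-likelihood for a gold label (say, index 0): the minus
# sign turns the negative log-probability into a positive loss value
# that the optimizer can minimize.
nll = -log_probs[0]
assert nll >= 0
```

So a tensor full of values like -15.48 simply means the model assigns those labels probability exp(-15.48), and negating the summed log-probabilities yields the (positive) cross-entropy loss.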