
Question about threshold of mask in baseline #28

Closed
Dirtybluer opened this issue Nov 17, 2020 · 1 comment
@Dirtybluer

I reviewed the commit history and found the commit postprocess Mask by HardThreshold.

As far as I understand, this is supposed to be the baseline described in the paper, though I'm not entirely sure.

One thing I find a bit confusing is that the threshold for the mask head (i.e. for Masker) is set to 0.01 here. Shouldn't it be 0.5 after applying sigmoid()?

I've also noticed that you moved sigmoid() from the post-processing to the predictor. However, I suppose that doesn't change the values fed into Masker, right? I'd like to know why moving sigmoid() was necessary.

Looking forward to your reply! @JingChaoLiu @liuxuebo0

@JingChaoLiu
Collaborator

This commit is an ablation study of the plane clustering step in PMTD. The model is still trained with the pyramid label, and it still outputs a pyramid mask. However, the box is not predicted by plane clustering; instead we use direct hard thresholding. In principle the cut-off for the text boundary should be where the pyramid mask's z value equals 0, but for robustness we raised the threshold to 0.01. In other words, the 0.01 here means "obtain the text boundary by hard thresholding".
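A minimal toy sketch of what "hard thresholding instead of plane clustering" means, assuming a pyramid mask whose value rises from 0 at the text border to 1 at the center (this is illustrative code, not the repository's actual implementation; `mask_to_box` is a hypothetical helper):

```python
import numpy as np

def mask_to_box(pyramid_mask, thr=0.01):
    """Keep pixels whose pyramid value exceeds thr and return the
    bounding box (x1, y1, x2, y2) of the surviving region."""
    ys, xs = np.nonzero(pyramid_mask > thr)
    if len(xs) == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

# Toy 7x7 pyramid mask: 0 on the border, 1.0 at the center pixel.
h = w = 7
yy, xx = np.mgrid[0:h, 0:w]
pyramid = np.minimum.reduce([yy, xx, h - 1 - yy, w - 1 - xx]) / ((h - 1) / 2)

# thr=0.01 cuts just above the z=0 border, keeping the inner 5x5 region.
print(mask_to_box(pyramid, thr=0.01))
```

With an exact z = 0 cut-off, any noise in the predicted mask would leak pixels into the box, which is why the reply says the threshold was nudged up to 0.01 for robustness.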

As for why sigmoid was moved from the post-processing into the main model, the reason is this:
When you do binary classification,
in model.train(), the binary cross entropy loss actually does two things to pred:
pred = pred.sigmoid(); loss = -(gt * log(pred) + (1 - gt) * log(1 - pred))
while in model.eval(), the post-processing does one thing to pred:
pred = pred.sigmoid()

But when you instead regress the pyramid mask ∈ [0, 1], the loss becomes an L1 loss:
in model.train(), the pred = pred.sigmoid() that binary cross entropy used to do for you now has to be done by the model itself.
And in model.eval(), since the main model already contains pred = pred.sigmoid(), the post-processing no longer needs to call sigmoid again.
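The two regimes contrasted above can be sketched in plain NumPy (the function names are mine; `bce_with_logits` mirrors what a BCE-with-logits loss does internally):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(logits, gt):
    # Binary classification: the loss applies sigmoid internally,
    # so the model emits raw logits and eval-time post-processing
    # must call sigmoid itself.
    p = sigmoid(logits)
    return -(gt * np.log(p) + (1 - gt) * np.log(1 - p)).mean()

def l1_on_sigmoid(logits, gt):
    # Pyramid-mask regression: L1 loss has no built-in sigmoid, so
    # the sigmoid moves into the model. Post-processing must then
    # NOT apply it a second time.
    p = sigmoid(logits)
    return np.abs(p - gt).mean()
```

Either way the values fed into Masker are post-sigmoid probabilities, which matches the questioner's guess: moving sigmoid only changes *where* it is applied, not what Masker sees.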
