Question about the target of NCE loss? #56
Comments
That's what I understood at first. However, on a second look, it seems correct. `predictions` is a (B*S, S+1) matrix: each row holds l_pos as its first element, followed by the S values of l_neg. `targets` is a B*S vector that points to the target element in each row. To tell the cross-entropy loss that l_pos is the target element, we pass 0 for every row. See also `torch.nn.CrossEntropyLoss`, where the docs explain that it expects a class index in the range [0, C-1] as the target for each value of a 1D tensor of size minibatch.
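To make the point concrete, here is a small self-contained sketch (the logit values are made up for illustration) showing that a target of 0 tells `CrossEntropyLoss` that column 0, i.e. l_pos, is the "correct class":

```python
import torch
import torch.nn.functional as F

# Hypothetical toy logits: 3 patches, each with 1 positive + 4 negative scores.
# Column 0 plays the role of l_pos; columns 1..4 play the role of l_neg.
logits = torch.tensor([[5.0, 1.0, 0.5, 0.2, 0.1],
                       [4.0, 2.0, 1.0, 0.3, 0.2],
                       [6.0, 0.5, 0.4, 0.3, 0.1]])

# Target index 0 for every row marks column 0 (l_pos) as the target class
# out of the S+1 candidates.
targets = torch.zeros(logits.size(0), dtype=torch.long)

# Loss is small because column 0 already dominates each row.
loss = F.cross_entropy(logits, targets)
print(loss.item())
```

Minimizing this loss pushes the column-0 logit up relative to the others, which is exactly "maximize l_pos, minimize l_neg".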
@Kppeppas That makes sense. I am working on a TensorFlow implementation, and I figured it had something to do with the way PyTorch handles the loss, so in my implementation I had been using a one-hot vector [1, 0, 0, ..., 0] as the target. Thanks for clarifying 👍.
So the target index is 0, which corresponds to l_pos?
@oochen Yes, for every patch s in S, the target index is 0 (which is l_pos) out of the 1+S possible indexes (l_pos + l_neg). A confirmation from the authors would be nice.
Thanks, @Kppeppas explained it. I did some experiments to explain the PatchNCE loss in more detail.

Intuitively, since the predictions have shape (B*S, 1(pos)+S(neg)), we might assume the target should be [1, 0, 0, ..., 0]. But if we trace the PyTorch docs for `torch.nn.CrossEntropyLoss`, the target vector contains the class index in [0, C-1] for each value. In other words, we want to maximize the column at the target class index and minimize the others.

However, the NCE loss is based on cosine similarity; it is not really a classification problem, it just produces similarity scores. So the code does a small trick:

```python
# calculate logits: (B)x(S)x(S+1); dim=2 means concatenate along the columns
logits = torch.cat((l_pos, l_neg), dim=2)
predictions = logits.flatten(0, 1)
```

This treats the NCE loss as a classification problem: each row is laid out as [p, n, n, ..., n]. So in the next lines, setting the targets to zero means we want to maximize the first column (pos, the zero index) and push down the others (negs). That's why we pass a zero vector:

```python
targets = torch.zeros(B * S, dtype=torch.long)
return cross_entropy_loss(predictions, targets)
```

I hope this post helps~
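Putting the pieces of the discussion together, here is a minimal sketch of such a loss, assuming `feat_q` and `feat_k` are L2-normalized patch features of shape (B, S, C); the temperature `tau` and the diagonal-masking value are assumptions for illustration, not necessarily what the repo uses:

```python
import torch
import torch.nn.functional as F

def patch_nce_loss(feat_q, feat_k, tau=0.07):
    """Sketch of a PatchNCE-style loss as discussed in this thread.

    feat_q, feat_k: (B, S, C) L2-normalized patch features.
    tau: hypothetical temperature (assumed value).
    """
    B, S, C = feat_q.shape
    feat_k = feat_k.detach()

    # Positive logit: similarity of each query patch with its own key patch.
    # Shape (B, S, 1).
    l_pos = (feat_q * feat_k).sum(dim=2, keepdim=True)

    # Negative logits: similarity of each query patch with every key patch
    # in the same image. Shape (B, S, S).
    l_neg = torch.bmm(feat_q, feat_k.transpose(1, 2))

    # Mask the diagonal so a patch is not counted as its own negative
    # (masking value is an assumption for illustration).
    diag = torch.eye(S, dtype=torch.bool, device=feat_q.device)[None, :, :]
    l_neg = l_neg.masked_fill(diag, -10.0)

    # (B, S, 1+S) -> (B*S, 1+S); column 0 is l_pos, the rest are l_neg.
    logits = torch.cat((l_pos, l_neg), dim=2).flatten(0, 1) / tau

    # Target index 0 for every row points at the l_pos column.
    targets = torch.zeros(B * S, dtype=torch.long, device=feat_q.device)
    return F.cross_entropy(logits, targets)
```

Note how the all-zeros target vector falls out naturally: after flattening, the positive always sits in column 0 of every row, so class index 0 is the target for every patch.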
Why are the targets all zeros? Shouldn't the target for l_pos be 1?
Thank you.