
Question about the target of NCE loss? #56

Open · oochen opened this issue Nov 10, 2020 · 5 comments

oochen commented Nov 10, 2020

targets = torch.zeros(B * S, dtype=torch.long)
return cross_entropy_loss(predictions, targets)

Why are the targets all zeros? Shouldn't the target for l_pos be 1?
Thank you.

@Kppeppas

That's what I understood at first too. Having a second look, however, it is correct. predictions will be a (B*S, S+1) matrix: the first element of each row is l_pos, and the remaining S elements are l_neg. targets is a B*S vector that holds, for each row, the index of the correct element. To tell the cross-entropy loss that l_pos is the correct element, we input 0 for each of them. See also torch.nn.CrossEntropyLoss, which explains that it expects a class index in the range [0, C-1] as the target for each value of a 1D tensor of size minibatch.
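To make this concrete, here is a minimal sketch (not the repo's code; the toy sizes B, S and the random logits are made up) showing that a target of 0 makes the cross-entropy loss treat column 0, i.e. l_pos, as the correct class:

import torch
import torch.nn.functional as F

B, S = 2, 4                                      # toy batch size and patch count
l_pos = torch.randn(B * S, 1)                    # similarity to the positive patch
l_neg = torch.randn(B * S, S)                    # similarities to the S negatives

predictions = torch.cat([l_pos, l_neg], dim=1)   # (B*S, S+1), column 0 is l_pos
targets = torch.zeros(B * S, dtype=torch.long)   # index 0 points at the l_pos column

loss = F.cross_entropy(predictions, targets)

# Pushing l_pos up while leaving the negatives alone lowers the loss,
# which is exactly the contrastive behaviour we want.
better = torch.cat([l_pos + 5.0, l_neg], dim=1)
assert F.cross_entropy(better, targets) < loss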

@Shalomash

@Kppeppas That makes sense. I am working on a TensorFlow implementation and figured it had something to do with the way PyTorch handles the loss, so in my implementation I had been using a one-hot vector [1, 0, 0, ...] over the npatches logits as the target. Thanks for clarifying 👍.
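For what it is worth, the two target conventions compute the same number. A hedged sketch (toy shapes only, names here are illustrative): a dense one-hot target [1, 0, ..., 0] combined with log-softmax gives the same loss as PyTorch's class-index target 0.

import torch
import torch.nn.functional as F

B, S = 2, 4
predictions = torch.randn(B * S, S + 1)          # column 0 holds l_pos

# PyTorch style: sparse class indices, all zeros
loss_index = F.cross_entropy(predictions, torch.zeros(B * S, dtype=torch.long))

# One-hot style: [1, 0, ..., 0] per row, combined with log-softmax by hand
one_hot = torch.zeros(B * S, S + 1)
one_hot[:, 0] = 1.0
loss_one_hot = -(one_hot * F.log_softmax(predictions, dim=1)).sum(dim=1).mean()

assert torch.allclose(loss_index, loss_one_hot)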

oochen (Author) commented Nov 12, 2020


So the target index is 0, which means l_pos?

@Kppeppas

@oochen Yes, for every patch s in S, the target index is 0 (which is l_pos) out of the 1+S possible indices (l_pos + l_neg). A confirmation from the authors would be nice.

@xiaosean

Thanks, @Kppeppas explained it.

I ran a small experiment to explain the PatchNCE loss in more detail.

Intuitively, we expect predictions to have shape (B*S, 1(pos)+S(neg)),

so we assume the target should be the one-hot vector [1, 0, 0, ..., 0].

But if we trace the explanation in the PyTorch docs for torch.nn.CrossEntropyLoss, the official description of the target argument says otherwise:

[screenshot of the torch.nn.CrossEntropyLoss documentation describing the target argument]

In brief, the target vector holds a class index in [0, C-1] for each sample. In other words, we want to maximize the logit in the target class's column and suppress the others.
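Written out for a single row x of predictions with target class t, the loss PyTorch computes is the usual softmax cross-entropy

loss(x, t) = -log( exp(x[t]) / sum_j exp(x[j]) )

so with t = 0 the loss is small exactly when l_pos = x[0] dominates the negative logits.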

However, the NCE loss is built from cosine similarities; by itself it is not a classification problem, the values are just similarity scores.

So the code does a little reshaping to turn it into one.

dim=2 means concatenating along the last (column) dimension:

# calculate logits: (B)x(S)x(S+1)
logits = torch.cat((l_pos, l_neg), dim=2)
predictions = logits.flatten(0, 1)

Some results from running these lines in a notebook are shown below:

[two screenshots of the resulting tensors and their shapes]
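For reference, a quick shape check along the same lines (the toy sizes B=2, S=4 and the random tensors are assumptions for illustration, not the repo's values):

import torch

B, S = 2, 4
l_pos = torch.randn(B, S, 1)                 # positive similarity per patch
l_neg = torch.randn(B, S, S)                 # similarities to the S negatives

logits = torch.cat((l_pos, l_neg), dim=2)    # (B, S, S+1): index 0 of the last dim is the positive
predictions = logits.flatten(0, 1)           # (B*S, S+1): one row per patch

print(logits.shape)        # torch.Size([2, 4, 5])
print(predictions.shape)   # torch.Size([8, 5])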

We want to treat the NCE loss as a classification problem,

so each row of predictions is laid out as

[p, n, n, n, n, n, n, n, n, n]
[o, e, e, e, e, e, e, e, e, e]
[s, g, g, g, g, g, g, g, g, g]

(read vertically: the first column of every row is the positive logit, the remaining columns are the negatives).

So in the next lines, setting every target to zero means we want to maximize the first column (pos, the zero index) relative to the others (the negatives).

That is why we set a zero vector: for every row, class index 0 (the positive) is the correct class.

targets = torch.zeros(B * S, dtype=torch.long)

return cross_entropy_loss(predictions, targets)
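Putting the pieces together, here is a minimal, hedged sketch of a PatchNCE-style loss. The feature arguments feat_q / feat_k, the temperature value, and the within-image negative masking are assumptions for illustration, not the repo's exact implementation.

import torch
import torch.nn.functional as F

def patch_nce_loss(feat_q, feat_k, temperature=0.07):
    # feat_q, feat_k: (B, S, C) patch features to be compared (illustrative shapes)
    B, S, C = feat_q.shape
    feat_q = F.normalize(feat_q, dim=2)
    feat_k = F.normalize(feat_k, dim=2).detach()

    # positive logit: similarity of each query patch to its own key patch -> (B, S, 1)
    l_pos = (feat_q * feat_k).sum(dim=2, keepdim=True)

    # negative logits: similarity to the other patches of the same image -> (B, S, S)
    l_neg = torch.bmm(feat_q, feat_k.transpose(1, 2))
    # mask the diagonal so a patch is not used as its own negative
    l_neg.masked_fill_(torch.eye(S, dtype=torch.bool, device=feat_q.device), -10.0)

    logits = torch.cat((l_pos, l_neg), dim=2) / temperature   # (B, S, S+1)
    predictions = logits.flatten(0, 1)                        # (B*S, S+1)
    targets = torch.zeros(B * S, dtype=torch.long, device=feat_q.device)
    return F.cross_entropy(predictions, targets)

# usage with random toy features
loss = patch_nce_loss(torch.randn(2, 4, 16), torch.randn(2, 4, 16))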

I hope this post helps you~
