Hello, again.
I'm studying your paper and code. However, I have a question about the following code in your 'cost.py' file:
```python
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [a1, a2]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [a1, a2_neg])
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= T
# labels: positive key indicators - first dim of each batch
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = F.cross_entropy(logits, labels)
```
I think the number of instances (N) must be 1, because when the cross-entropy is calculated, the positive key's label becomes 1.
I'm not sure I understand this correctly, so I would appreciate your advice. Thank you.
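
For reference, here is a minimal standalone sketch of the shapes involved; the sizes N=4, C=8, K=16 and the temperature value 0.07 are placeholders I made up, not values from your repo:

```python
import torch
import torch.nn.functional as F

# made-up toy sizes, not from the repo
N, C, K = 4, 8, 16
T = 0.07  # placeholder temperature

a1 = torch.randn(N, C)      # query features
a2 = torch.randn(N, C)      # positive key features
a2_neg = torch.randn(C, K)  # bank of negative keys

l_pos = torch.einsum('nc,nc->n', [a1, a2]).unsqueeze(-1)  # Nx1
l_neg = torch.einsum('nc,ck->nk', [a1, a2_neg])           # NxK
logits = torch.cat([l_pos, l_neg], dim=1) / T             # Nx(1+K)

# .cuda() omitted so this runs on CPU
labels = torch.zeros(logits.shape[0], dtype=torch.long)
loss = F.cross_entropy(logits, labels)
print(logits.shape, labels, loss.item())
```

Even with N = 4 here, `labels` comes out as all zeros, which is the part that confuses me.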