If the batch size is 64, the `labels` variable has a shape of `[64]`.
When the first line of the referenced code runs, the comparison `([64, 1] == [1, 64]).float()` produces a `[64, 64]` tensor, which is exactly a 2D diagonal matrix. But the problem is in the second line:

```python
mask = mask - torch.diag(torch.diag(mask))
```
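For concreteness, the broadcasting step can be sketched with NumPy as a stand-in for the torch version. The label values here are hypothetical and chosen to be all distinct within the batch, which is what the diagonal-matrix claim assumes:

```python
import numpy as np

# Hypothetical batch of 4 distinct labels (stand-in for the real [64] labels tensor).
labels = np.array([0, 1, 2, 3])

# labels[:, None] has shape [4, 1]; labels[None, :] has shape [1, 4].
# Broadcasting the == comparison yields a [4, 4] matrix:
# mask[i, j] == 1.0 exactly when labels[i] == labels[j].
mask = (labels[:, None] == labels[None, :]).astype(np.float32)

print(mask.shape)  # (4, 4)
# With all-distinct labels, only the diagonal entries are 1,
# so mask equals the 4x4 identity matrix.
print(np.array_equal(mask, np.eye(4, dtype=np.float32)))  # True
```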
When `torch.diag(mask)` runs, the result has a shape of `[64]` and is a ones-filled vector: $[1, 1, 1, \ldots]$. Therefore, the result of `torch.diag(torch.diag(mask))` is exactly the same as `mask`, which is exactly a 2D diagonal matrix. Furthermore, when you subtract this result from `mask`, the final `mask` is always a zero-filled matrix.
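The two-step `diag` behavior described above can likewise be checked with NumPy, whose `np.diag` mirrors `torch.diag` (it extracts the diagonal from a 2-D input and builds a diagonal matrix from a 1-D input). Again, the labels are hypothetical and all distinct, matching the assumption above:

```python
import numpy as np

labels = np.array([0, 1, 2, 3])  # hypothetical distinct labels
mask = (labels[:, None] == labels[None, :]).astype(np.float32)  # identity here

# Step 1: diag of a 2-D matrix extracts the diagonal -> shape [4], all ones.
diag_vec = np.diag(mask)
print(diag_vec.shape, diag_vec)  # (4,) [1. 1. 1. 1.]

# Step 2: diag of a 1-D vector builds a diagonal matrix -> the identity again.
diag_mat = np.diag(diag_vec)
print(np.array_equal(diag_mat, mask))  # True, since mask itself is diagonal

# Subtracting therefore zeroes out the whole mask (under this assumption).
print(np.array_equal(mask - diag_mat, np.zeros((4, 4))))  # True
```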
As a result, the `mask` variable has no effect on gradient descent.
Is this really intentional? I thought the `mask` variable was used to distinguish $P(i)$ from $N(i)$ in the equation.
Is this right? Or am I missing a point?
For reference, my question concerns the implementation of the margin-based contrastive loss at https://github.com/wzhouad/Contra-OOD/blob/main/model.py#L40.