
nn.CosineSimilarity returns value larger than 1 #78064

Status: Open
skyelves opened this issue May 22, 2022 · 14 comments

Labels: module: correctness (silent) (issue that returns an incorrect result silently), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@skyelves commented May 22, 2022

🐛 Describe the bug

nn.CosineSimilarity returns a value larger than 1.

When I computed the cosine similarity below, it printed tensor([1.0000]). The stored value, however, is slightly larger than 1, which triggers a RuntimeError in BCELoss.

To reproduce the bug

import torch
import torch.nn as nn
t1 = torch.tensor([[1.6965e-02, 0.0000e+00, 1.5725e-02, 0.0000e+00, 9.7518e-03, 4.1566e-03,
         2.8437e-03, 1.2394e-03, 0.0000e+00, 4.4327e-02, 6.6013e-02, 2.3693e-02,
         1.2146e-02, 9.4390e-03, 0.0000e+00, 2.4374e-02, 0.0000e+00, 0.0000e+00,
         9.9630e-04, 8.2091e-03, 8.6477e-05, 0.0000e+00, 1.2825e-02, 0.0000e+00,
         1.5316e-03, 0.0000e+00, 4.4526e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 9.7517e-04, 3.3356e-02, 8.4023e-08, 5.8102e-04, 0.0000e+00,
         2.3170e-02, 0.0000e+00, 0.0000e+00, 7.8518e-03, 0.0000e+00, 1.9662e-02,
         7.7019e-05, 1.7013e-02, 4.0341e-02, 3.7943e-03, 2.0059e-02, 1.6905e-02,
         0.0000e+00, 0.0000e+00, 3.3092e-02, 0.0000e+00, 2.0570e-04, 6.7327e-03,
         0.0000e+00, 0.0000e+00, 8.5911e-04, 0.0000e+00, 0.0000e+00, 1.9356e-02,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9724e-02]])
t2 = torch.tensor([[1.6965e-02, 0.0000e+00, 1.5725e-02, 0.0000e+00, 9.7522e-03, 4.1569e-03,
         2.8436e-03, 1.2394e-03, 0.0000e+00, 4.4329e-02, 6.6014e-02, 2.3694e-02,
         1.2146e-02, 9.4390e-03, 0.0000e+00, 2.4375e-02, 0.0000e+00, 0.0000e+00,
         9.9659e-04, 8.2090e-03, 8.6500e-05, 0.0000e+00, 1.2826e-02, 0.0000e+00,
         1.5317e-03, 0.0000e+00, 4.4532e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 9.7523e-04, 3.3357e-02, 8.3033e-08, 5.8104e-04, 0.0000e+00,
         2.3171e-02, 0.0000e+00, 0.0000e+00, 7.8521e-03, 0.0000e+00, 1.9662e-02,
         7.7023e-05, 1.7013e-02, 4.0342e-02, 3.7944e-03, 2.0059e-02, 1.6906e-02,
         0.0000e+00, 0.0000e+00, 3.3093e-02, 0.0000e+00, 2.0572e-04, 6.7329e-03,
         0.0000e+00, 0.0000e+00, 8.5914e-04, 0.0000e+00, 0.0000e+00, 1.9357e-02,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9725e-02]])
print(t1.size(), t2.size())
pred = nn.CosineSimilarity(dim=1, eps=1e-8)
cos = pred(t1, t2)
print(cos, cos>1)
criterion = nn.BCELoss()
res = criterion(cos, torch.ones(1))
print(res)

It produces the RuntimeError:

torch.Size([1, 64]) torch.Size([1, 64])
tensor([1.0000]) tensor([True])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-beebcac695ff> in <module>
      3 print(cos, cos>1)
      4 criterion = nn.BCELoss()
----> 5 res = criterion(cos, torch.ones(1))
      6 print(res)

~/anaconda3/envs/patent-isic/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/patent-isic/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    601 
    602     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 603         return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
    604 
    605 

~/anaconda3/envs/patent-isic/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2913         weight = weight.expand(new_size)
   2914 
-> 2915     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   2916 
   2917 

RuntimeError: all elements of input should be between 0 and 1
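
A quick way to make the excess visible (a hypothetical check, not part of the original report) is to print more digits, or to repeat the computation in float64:

print(f"{cos.item():.10f}")                # slightly above 1.0 in float32
print(pred(t1.double(), t2.double()) > 1)  # in double precision this likely stays <= 1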

Versions

Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 11.6 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0
CMake version: version 3.15.5
Libc version: N/A

Python version: 3.7.5 (default, Oct 25 2019, 10:52:18) [Clang 4.0.1 (tags/RELEASE_401/final)] (64-bit runtime)
Python platform: Darwin-20.6.0-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.5
[pip3] torch==1.10.2
[pip3] torch-cluster==1.6.0
[pip3] torch-geometric==2.0.4
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.13
[pip3] torch-spline-conv==1.2.1
[conda] blas 1.0 mkl
[conda] mkl 2019.0 pypi_0 pypi
[conda] mkl-service 2.3.0 py37h9ed2024_1
[conda] mkl_fft 1.3.0 py37h4a7008c_2
[conda] mkl_random 1.2.1 py37hb2f4e1b_2
[conda] numpy 1.17.3 pypi_0 pypi
[conda] numpy-base 1.21.5 py37h3b1a694_1
[conda] pyg 2.0.4 py37_torch_1.10.0_cpu pyg
[conda] pytorch 1.10.2 cpu_py37h903acac_0
[conda] pytorch-cluster 1.6.0 py37_torch_1.10.0_cpu pyg
[conda] pytorch-scatter 2.0.9 py37_torch_1.10.0_cpu pyg
[conda] pytorch-sparse 0.6.13 py37_torch_1.10.0_cpu pyg
[conda] pytorch-spline-conv 1.2.1 py37_torch_1.10.0_cpu pyg
[conda] torch 1.4.0 pypi_0 pypi
[conda] torch-geometric 2.0.5 pypi_0 pypi

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

@ngimel (Collaborator) commented May 22, 2022

cc @nikitaved

@mrshenli added the module: nn, triaged, and module: correctness (silent) labels May 22, 2022
@nikitaved (Collaborator) commented May 22, 2022

We have made a patch for cosine similarity which, I believe, will be exposed in the upcoming release. Could you please try a nightly release and/or the current master?

EDIT: I can confirm this issue is present on master for float but not for double.

It looks like computing a squared norm followed by sqrt is not the same as calling norm, since the following works just fine:

In [12]: ll = ((t1 / t1.norm()) * (t2 / t2.norm())).sum()

In [13]: ll > 1
Out[13]: tensor(False)

I will try to come up with an update that uses norm, since it appears to handle reductions with better precision; the precision seems to be lost in the sqrt step.

@skyelves (Author)

Yeah, thanks for pointing out the problem. Do you have any suggestions for how to solve it or work around it?

@ngimel (Collaborator) commented May 23, 2022

@nikitaved note that the current cosine_similarity implementation breaks #18057 again. Probably we should just do what SciPy does; the link is in that issue.

@nikitaved (Collaborator) commented May 23, 2022

@ngimel, I have mentioned that in my post. Note that using norm instead does not fail. The reason it is more stable is that norm appears to use double as the accumulation type, and sqrt is then run on a double; the solution on master runs sqrt on a float tensor, which I suppose is why it loses precision. These two solutions, norm and non-norm, solve different problems: norm is the most stable in the forward pass, but does not handle zeros in the backward pass for inputs with close-to-zero norms. Maybe we could implement cosine_similarity as a non-composite function with a custom backward? What SciPy is doing is worse; it is exactly the previous behavior the OP complains about.
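
For illustration, a rough sketch of what such a non-composite version with a custom backward could look like (my own sketch, not a PyTorch implementation; the near-zero-norm handling here is just an eps clamp):

import torch

class StableCosineSimilarity(torch.autograd.Function):
    """Sketch only: norm()-based forward with a hand-written backward."""

    @staticmethod
    def forward(ctx, a, b, dim, eps):
        a_norm = a.norm(dim=dim, keepdim=True).clamp_min(eps)
        b_norm = b.norm(dim=dim, keepdim=True).clamp_min(eps)
        out = ((a / a_norm) * (b / b_norm)).sum(dim=dim)
        ctx.save_for_backward(a, b, a_norm, b_norm, out)
        ctx.dim = dim
        return out

    @staticmethod
    def backward(ctx, grad_out):
        a, b, a_norm, b_norm, out = ctx.saved_tensors
        grad_out = grad_out.unsqueeze(ctx.dim)
        cos = out.unsqueeze(ctx.dim)
        # d cos / d a = b / (|a||b|) - cos * a / |a|^2, and symmetrically for b
        grad_a = grad_out * (b / (a_norm * b_norm) - cos * a / a_norm.pow(2))
        grad_b = grad_out * (a / (a_norm * b_norm) - cos * b / b_norm.pow(2))
        return grad_a, grad_b, None, None

cos = StableCosineSimilarity.apply(t1, t2, 1, 1e-8)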

@nikitaved (Collaborator) commented May 23, 2022

@skyelves, if you do not expect to backprop through inputs with close-to-zero norms, you could try the code sample from my very first message, adapted to your case (i.e. the dim argument).
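
For concreteness, a minimal sketch of that adaptation for the dim=1 case in the report (the helper name and the eps clamp are my additions; it does not protect the backward pass for near-zero-norm inputs):

import torch

def cos_sim_dim1(a, b, eps=1e-8):
    # normalize each row with norm(), then take the row-wise dot product
    a_n = a / a.norm(dim=1, keepdim=True).clamp_min(eps)
    b_n = b / b.norm(dim=1, keepdim=True).clamp_min(eps)
    return (a_n * b_n).sum(dim=1)

cos = cos_sim_dim1(t1, t2)  # t1, t2 as defined in the report above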

@ngimel (Collaborator) commented May 23, 2022

SciPy produces the correct cosine similarity on equal inputs, which is currently broken, so I wouldn't call it "worse". Also, on CUDA, norm is not accumulated in double and sqrt isn't run in double, so I'm not sure that argument applies. And #65815 indicates that even on CPU, norm isn't run in double.

@nikitaved (Collaborator) commented May 23, 2022

@ngimel, the SciPy implementation is what PyTorch 1.10 implemented, and that one was failing for other inputs. I am just not sure we can guarantee stable behavior for all types of inputs. For example, the SciPy implementation will suffer from precision degradation for inputs with large norms.

@ngimel (Collaborator) commented May 23, 2022

I understand (although 1.10 didn't implement SciPy exactly; it was still using sums, not means). But currently we have traded one set of failing inputs for another, so it's not strictly better.

@nikitaved (Collaborator) commented May 23, 2022

Oh, that is a good point. We can indeed scale the inputs prior to the norm computation, as this similarity is homogeneous in each argument. Maybe this will solve the issue; I will give it a shot.
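
A rough sketch of how that scaling could look (my reading of the idea, not the eventual patch): because cosine similarity is unchanged when either argument is rescaled by a positive factor, each input can be divided by its largest absolute entry before the norms are taken, keeping the intermediates well inside floating-point range:

import torch

def scaled_cos_sim(a, b, dim=1, eps=1e-8):
    # rescaling each slice by its max absolute value leaves the similarity unchanged
    a = a / a.abs().amax(dim=dim, keepdim=True).clamp_min(eps)
    b = b / b.abs().amax(dim=dim, keepdim=True).clamp_min(eps)
    a_n = a / a.norm(dim=dim, keepdim=True).clamp_min(eps)
    b_n = b / b.norm(dim=dim, keepdim=True).clamp_min(eps)
    return (a_n * b_n).sum(dim=dim)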

@ngimel (Collaborator) commented May 23, 2022

But still, maybe at the end of the day we should just clamp the outputs so that they are no bigger than one. Gradient computation is a different matter, so having an accurate forward still matters (that clamping would have to happen in a no-autograd context, as it does now).
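
One way such clamping could be written (an illustration of the idea only, not PyTorch's actual code): clip the stored values so they never exceed 1 in magnitude, while autograd still sees the unclamped computation:

import torch
import torch.nn.functional as F

def cosine_similarity_clamped(a, b, dim=1, eps=1e-8):
    out = F.cosine_similarity(a, b, dim=dim, eps=eps)
    # adjust the values outside autograd; gradients still flow through `out`
    clamped = out.detach().clamp(-1.0, 1.0)
    return out + (clamped - out.detach())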

@skyelves (Author)

Actually, I do need to backprop in order to train the network.
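
For anyone who needs backprop in the meantime, one user-side stopgap (just a suggestion, not an official fix) is to clip the similarity into BCELoss's valid range with torch.clamp, which still passes gradients wherever the value is not clipped:

# `pred` and `criterion` as in the report above; in a real model the inputs to
# `pred` would require grad
cos = pred(t1, t2)
loss = criterion(cos.clamp(0.0, 1.0), torch.ones(1))
# clamp is differentiable away from the boundaries: the gradient passes through
# where cos was inside [0, 1] and is zero where it was clipped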

@subercui

Hi, I am encountering the same issue and also need to backprop. Has this been solved in any recent release, or is there a workaround so far?

@nikitaved self-assigned this Jul 21, 2022

@nikitaved (Collaborator)

@subercui, not yet. I am submitting a fix now.
