torch.norm produces incorrect results #20551
I think this is a float precision issue. In float64, it seems to work fine.
cc: @umanwizard to double-check. I think you made norm TensorIterator-compatible, if I remember correctly (sorry if it wasn't you).
No, it was done by @jjsjann123 in #15414.
The code above gives correct results in PyTorch 0.4 but not in PyTorch 1.1.0. If it is a float precision issue, I am not sure why it works correctly in PyTorch 0.4.
PyTorch 0.4 did the accumulation using double: https://github.com/pytorch/pytorch/blob/v0.4.1/aten/src/TH/generic/THTensorMath.cpp#L4307. Now it's using float accumulation. CUDA also uses float accumulation, but is saved because the necessary parallelism forces a form of pairwise summation. We should probably do the same thing for CPU.
Like Sam mentioned, numerical behavior is very different on CPU/GPU because of the level of parallelism. Sacrifices like this (using …)
@jjsjann123 I was suggesting something different: that we use pairwise summation in the CPU reduction kernels, not that we switch to double accumulation.
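The pairwise-summation idea suggested here can be sketched in plain Python, emulating float32 rounding with `struct` so no PyTorch install is needed. The input (a vector of 0.1s) and its size are illustrative assumptions, not taken from the issue:

```python
import struct

def f32(x):
    # Round a Python float to the nearest IEEE-754 float32 value.
    return struct.unpack("f", struct.pack("f", x))[0]

def naive_sum_f32(xs):
    # Sequential accumulation, rounding to float32 after every add
    # (roughly what a simple CPU reduction loop does).
    acc = f32(0.0)
    for x in xs:
        acc = f32(acc + x)
    return acc

def pairwise_sum_f32(xs):
    # Recursive pairwise summation, still rounding to float32 at every
    # step; the error grows like O(log n) instead of O(n).
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return f32(pairwise_sum_f32(xs[:mid]) + pairwise_sum_f32(xs[mid:]))

n = 1 << 18                      # 262,144 elements (arbitrary choice)
xs = [f32(0.1)] * n
exact = 0.1 * n                  # double-precision reference
naive_err = abs(naive_sum_f32(xs) - exact)
pairwise_err = abs(pairwise_sum_f32(xs) - exact)
print(naive_err, pairwise_err)
```

The pairwise variant keeps everything in float32 (no double accumulation), yet its error is orders of magnitude smaller than the sequential loop's, which is why the CUDA kernel's parallel reduction is "saved" without any dtype change.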
This is still an issue. |
This still exists. Could anyone take a look, or change the ATen operator in the torch CPU backend?
🐛 Bug
torch.norm gives incorrect results on CPU in the latest nightly build as well as in 1.1.0 stable.
To Reproduce
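The original repro snippet was not captured in this text. A sketch of the failure mode in plain Python (the random data, the tensor size, and the names `b` and `c` are assumptions), emulating float32 accumulation with `struct`:

```python
import math
import random
import struct

def f32(x):
    # Round a Python float to the nearest IEEE-754 float32 value.
    return struct.unpack("f", struct.pack("f", x))[0]

random.seed(0)
data = [f32(random.random()) for _ in range(1 << 18)]

# b: 2-norm with the sum of squares accumulated in float32, emulating
# what the 1.1.0 CPU reduction kernel effectively does.
acc32 = f32(0.0)
for x in data:
    acc32 = f32(acc32 + f32(x * x))
b = f32(math.sqrt(acc32))

# c: the same 2-norm with the sum of squares accumulated exactly
# (math.fsum), emulating the PyTorch 0.4 double-accumulation behavior.
c = math.sqrt(math.fsum(x * x for x in data))

print(b, c)
```

With a quarter-million elements the two results visibly disagree in the low digits, matching the report that the float32 CPU path drifts while the float64 path stays correct.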
Expected behavior
Both b and c should have the same values.
Environment
PyTorch version: 1.1.0.dev20190514
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Red Hat Enterprise Linux Server release 7.4 (Maipo)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-16)
CMake version: version 2.8.12.2
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla K40m
GPU 1: Tesla K40m
Nvidia driver version: 387.26
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.1
[pip3] numpy==1.14.3
[pip3] torch==0.4.0
[pip3] torchtext==0.2.3
[pip3] torchvision==0.2.1
[conda] blas 1.0 mkl
[conda] mkl 2018.0.2 1
[conda] mkl_fft 1.0.1 py36h3010b51_0
[conda] mkl_random 1.0.1 py36h629b387_0
[conda] pytorch-nightly 1.1.0.dev20190514 py3.6_cuda9.0.176_cudnn7.5.1_0 pytorch
[conda] torchtext 0.2.3
Additional context