
Inconsistent results between CPU and GPU for many operators with complex inputs containing Inf #141487

@rookieLiu2018


🐛 Describe the bug

When computing on complex tensors that have Inf as the real or imaginary component, a wide range of PyTorch operators return different results on CPU and GPU. Affected operators include fundamental element-wise math functions such as torch.sin, torch.cos, torch.acos, torch.tanh, torch.exp, and torch.rsqrt, as well as the reduction torch.mean.
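Part of why Inf components are fragile is that IEEE 754 defines `0 * inf` as NaN, so evaluating a textbook identity such as sin(x + iy) = sin(x)cosh(y) + i·cos(x)sinh(y) without special-value handling can turn an exact zero into NaN. The sketch below uses only Python's `math` module (not PyTorch), and `naive_csin` is a hypothetical helper; it illustrates the failure mode and makes no claim about how either PyTorch backend is implemented. Mathematically sin(iy) = i·sinh(y) has a zero real part, and the GPU output for this element is `0.+infj`, while the CPU output is `inf+infj`.

```python
import math

def naive_csin(z: complex) -> complex:
    """Hypothetical helper: sin(x + iy) via sin(x)cosh(y) + i*cos(x)sinh(y),
    with no special-value handling. math.sin raises ValueError for an
    infinite argument, so only finite real parts are supported."""
    x, y = z.real, z.imag
    return complex(math.sin(x) * math.cosh(y), math.cos(x) * math.sinh(y))

z = complex(0.0, math.inf)  # the second test input, 0 + inf*j
w = naive_csin(z)
# sin(0) * cosh(inf) is 0.0 * inf, which IEEE 754 defines as NaN,
# so the real part is NaN even though the exact real part is 0.
print(w)  # (nan+infj)
```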

MRE:

```python
import torch

test_inputs = [
    torch.tensor([complex(torch.inf, 0), complex(0, torch.inf), complex(torch.inf, torch.inf)], dtype=torch.complex128),
]
test_apis = [
    torch.sin, torch.cos, torch.tan, torch.acos, torch.asin, torch.atan,
    torch.sinh, torch.cosh, torch.tanh, torch.exp, torch.rsqrt, torch.mean
]

for api in test_apis:
    print(f"Testing {api.__name__}")
    for x in test_inputs:
        try:
            cpu_out = api(x)
            gpu_out = api(x.cuda())
            print(f"CPU Output: {cpu_out}")
            print(f"GPU Output: {gpu_out}")
        except Exception as e:
            print(f"Error in {api.__name__}: {e}")
```
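Printed tensors make some differences easy to miss, such as `inf-0.j` (CPU) versus `inf+0.j` (GPU) for the second `cos` element. A minimal sketch of a stricter element-wise check: two components match if both are NaN, or if they are equal and carry the same sign, so -0.0 and +0.0 are treated as different. The helpers `same_component` and `mismatched_indices` are hypothetical, and the sample values are transcribed by hand from the `tanh` and `cos` outputs in this report.

```python
import math

def same_component(a: float, b: float) -> bool:
    """True if both values are NaN, or equal with the same sign (-0.0 != +0.0)."""
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)
    return a == b and math.copysign(1.0, a) == math.copysign(1.0, b)

def mismatched_indices(cpu_vals, gpu_vals):
    """Indices where CPU and GPU complex results disagree in either component."""
    return [
        i for i, (c, g) in enumerate(zip(cpu_vals, gpu_vals))
        if not (same_component(c.real, g.real) and same_component(c.imag, g.imag))
    ]

inf, nan = math.inf, math.nan
# Values transcribed from the printed output in this report.
cpu_tanh = [complex(1, 0), complex(nan, nan), complex(nan, nan)]
gpu_tanh = [complex(1, 0), complex(nan, nan), complex(1, 0)]
cpu_cos = [complex(nan, nan), complex(inf, -0.0), complex(inf, -inf)]
gpu_cos = [complex(nan, -0.0), complex(inf, 0.0), complex(inf, nan)]

print(mismatched_indices(cpu_tanh, gpu_tanh))  # [2]
print(mismatched_indices(cpu_cos, gpu_cos))    # [0, 1, 2]
```

For the element-wise operators in the MRE, `cpu_out.tolist()` and `gpu_out.cpu().tolist()` should produce lists of Python `complex` values that can be passed to `mismatched_indices` directly.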

The output is:

```text
Testing sin
CPU Output: tensor([nan+nanj, inf+infj, inf+infj], dtype=torch.complex128)
GPU Output: tensor([nan+0.j, 0.+infj, nan+infj], device='cuda:0', dtype=torch.complex128)
Testing cos
CPU Output: tensor([nan+nanj, inf-0.j, inf-infj], dtype=torch.complex128)
GPU Output: tensor([nan-0.j, inf+0.j, inf+nanj], device='cuda:0', dtype=torch.complex128)
Testing tan
CPU Output: tensor([nan+nanj, 0.+1.j, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([nan+nanj, 0.+1.j, 0.+1.j], device='cuda:0', dtype=torch.complex128)
Testing acos
CPU Output: tensor([nan+nanj, nan+nanj, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([0.0000-infj, 1.5708-infj, 0.7854-infj], device='cuda:0',
       dtype=torch.complex128)
Testing asin
CPU Output: tensor([nan+nanj, nan+nanj, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([1.5708+infj, 0.0000+infj, 0.7854+infj], device='cuda:0',
       dtype=torch.complex128)
Testing atan
CPU Output: tensor([nan+nanj, nan+nanj, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([1.5708+0.j, 1.5708+0.j, 1.5708+0.j], device='cuda:0',
       dtype=torch.complex128)
Testing sinh
CPU Output: tensor([inf+infj, nan+nanj, inf+infj], dtype=torch.complex128)
GPU Output: tensor([inf+0.j, 0.+nanj, inf+nanj], device='cuda:0', dtype=torch.complex128)
Testing cosh
CPU Output: tensor([inf+0.j, nan+nanj, inf+infj], dtype=torch.complex128)
GPU Output: tensor([inf+0.j, nan+0.j, inf+nanj], device='cuda:0', dtype=torch.complex128)
Testing tanh
CPU Output: tensor([1.+0.j, nan+nanj, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([1.+0.j, nan+nanj, 1.+0.j], device='cuda:0', dtype=torch.complex128)
Testing exp
CPU Output: tensor([inf+nanj, nan+nanj, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([inf+0.j, nan+nanj, inf+nanj], device='cuda:0', dtype=torch.complex128)
Testing rsqrt
CPU Output: tensor([0.+0.j, nan+nanj, nan+nanj], dtype=torch.complex128)
GPU Output: tensor([0.+0.j, 0.-0.j, 0.-0.j], device='cuda:0', dtype=torch.complex128)
Testing mean
CPU Output: (nan+nanj)
GPU Output: (inf+infj)
```
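To judge which device is closer to a published standard, one option (an assumption here; the report does not say which semantics PyTorch targets) is to compare against the special-value rules for `cexp` in Annex G of the C99 standard. The hypothetical `annex_g_cexp` below encodes those rules for non-finite inputs and defers to `cmath.exp` for finite ones. Where Annex G leaves a sign unspecified, for example cexp(+∞ + i∞) = ±∞ + iNaN, the sketch picks the positive sign.

```python
import cmath
import math

def annex_g_cexp(z: complex) -> complex:
    """Hypothetical sketch of the C99 Annex G special values for cexp.

    Finite inputs are delegated to cmath.exp, which raises OverflowError
    if the magnitude of the result overflows a double.
    """
    x, y = z.real, z.imag
    if math.isnan(x):
        # cexp(NaN + i0) = NaN + i0; any other imaginary part gives NaN + iNaN.
        return complex(math.nan, y) if y == 0 else complex(math.nan, math.nan)
    if math.isinf(x) and x > 0:
        if y == 0:
            return complex(math.inf, y)           # +inf + i0, sign of zero kept
        if not math.isfinite(y):
            return complex(math.inf, math.nan)    # +-inf + iNaN, real sign unspecified
        return complex(math.inf * math.cos(y), math.inf * math.sin(y))  # +inf * cis(y)
    if math.isinf(x):  # x is -inf
        if not math.isfinite(y):
            return complex(0.0, 0.0)              # +-0 +- i0, signs unspecified
        return complex(0.0 * math.cos(y), 0.0 * math.sin(y))  # +0 * cis(y)
    if not math.isfinite(y):
        return complex(math.nan, math.nan)        # finite x with infinite or NaN y
    return cmath.exp(z)

# The three test inputs from the MRE.
for z in (complex(math.inf, 0.0), complex(0.0, math.inf), complex(math.inf, math.inf)):
    print(z, "->", annex_g_cexp(z))
```

For the three test inputs this yields inf+0j, nan+nanj and inf+nanj, which are the values the GPU reported for `torch.exp`; the CPU result differs in the first and third elements. The other operators would need their own Annex G tables (csinh, ccosh, ctanh, and so on), which this sketch does not cover.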

Versions

```text
Collecting environment information...
PyTorch version: 2.4.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.8rc1 (tags/v3.8.8rc1:dfd7d68, Feb 17 2021, 11:01:21) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 560.94
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: AMD Ryzen 7 5700X 8-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3401
MaxClockSpeed: 3401
L2CacheSize: 4096
L2CacheSpeed: None
Revision: 8450

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.4.1+cu121
[pip3] torchaudio==2.4.1+cu121
[pip3] torchvision==0.19.1+cu121
[conda] _anaconda_depends         2023.09             py311_mkl_1
[conda] blas                      1.0                         mkl
[conda] mkl                       2023.1.0         h6b88ed4_46357
[conda] mkl-service               2.4.0           py311h2bbff1b_1
[conda] mkl_fft                   1.3.8           py311h2bbff1b_0
[conda] mkl_random                1.2.4           py311h59b6b97_0
[conda] numpy                     1.24.3          py311hdab7c0b_1
[conda] numpy-base                1.24.3          py311hd01c5d8_1
[conda] numpydoc                  1.5.0           py311haa95532_0
[conda] torch                     2.1.0                    pypi_0    pypi
```

cc @ptrblck @msaroufim @eqy @ezyang @anjali411 @dylanbespalko @mruberry @nikitaved @amjames @malfet @janeyx99 @mingfeima @jiayisunx

Metadata

Assignees: none listed

Labels:
module: complex (Related to complex number support in PyTorch)
module: cuda (Related to torch.cuda, and CUDA support in general)
module: edge cases (Adversarial inputs unlikely to occur in practice)
module: numerical-stability (Problems related to numerical stability of operations)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type: No type
Projects: Status: Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
