Description
🐛 Bug
torch.argmax() returns different indices from numpy.argmax() when the maximum value occurs more than once (i.e. when there are ties).
To Reproduce
Code:
import torch
import numpy as np
arr1 = np.array([1,2,3])
arr2 = np.array([5,10,3])
my_list_arr = [arr1, arr2]
A = np.transpose(my_list_arr)
A
Output:
array([[ 1,  5],
       [ 2, 10],
       [ 3,  3]])
Code:
tsr1 = torch.tensor(np.array([1,2,3]))
tsr2 = torch.tensor(np.array([5,10,3]))
my_list_tsr = [tsr1, tsr2]
B = torch.stack(my_list_tsr, dim = 1)
B
Output:
tensor([[ 1,  5],
        [ 2, 10],
        [ 3,  3]])
Code:
np.argmax(A, axis = 1)
Output:
array([1, 1, 0])
Code:
torch.argmax(B, dim = 1)
Output:
tensor([1, 1, 1])
Expected behavior
numpy.argmax() and torch.argmax() should return the same indices, but they differ for the third row, where both entries equal 3: NumPy returns 0 (the first occurrence of the maximum), while PyTorch returns 1.
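The NumPy documentation for numpy.argmax states that, when the maximum value occurs more than once, the index of the first occurrence is returned, which is why the third row gives 0. A minimal pure-Python sketch of that rule, applied to the same rows as A and B above (first_max_index is a hypothetical helper, not part of NumPy or PyTorch):

```python
def first_max_index(values):
    """Return the index of the first maximum in `values` (NumPy's tie rule)."""
    best = 0
    for i in range(1, len(values)):
        # Strict '>' keeps the earlier index when a later value only ties.
        if values[i] > values[best]:
            best = i
    return best


# Same rows as A (NumPy) and B (PyTorch) in the reproduction.
rows = [[1, 5], [2, 10], [3, 3]]
print([first_max_index(row) for row in rows])  # [1, 1, 0]
```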
Environment
Collecting environment information...
PyTorch version: 1.5.1+cu101
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.12.0
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1+cu101
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.6.1+cu101
[conda] Could not collect
Additional context
For consistency, it would be helpful if torch.argmax() returned the same indices as numpy.argmax() when there are ties, since numpy.argmax() is the more commonly used function.
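As a possible workaround, the first maximal index can be computed explicitly instead of relying on torch.argmax()'s tie-breaking. This is a sketch under stated assumptions: the input is non-empty and contains no NaNs, and argmax_first_torch is a hypothetical helper, not a PyTorch API. It uses only standard tensor operations (max, arange, masked_fill, min).

```python
import torch


def argmax_first_torch(t: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: index of the first maximum along `dim`.

    Assumes a non-empty tensor without NaNs.
    """
    # True wherever an element equals the maximum of its slice along `dim`.
    is_max = t == t.max(dim=dim, keepdim=True).values
    # Positions 0..n-1 along `dim`, reshaped so they broadcast against `t`.
    n = t.size(dim)
    shape = [1] * t.dim()
    shape[dim] = n
    positions = torch.arange(n, device=t.device).view(shape).expand_as(t).contiguous()
    # Non-maximal positions become n (larger than any valid index), so the
    # smallest remaining position is the first maximum.
    return positions.masked_fill(~is_max, n).min(dim=dim).values


B = torch.tensor([[1, 5], [2, 10], [3, 3]])
print(argmax_first_torch(B, dim=1))
```

Only the values returned by min() are used, never its indices, so the result does not depend on how min() itself breaks ties.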