
torch.argmax() returning different indices to numpy.argmax() when the element values are the same #41998

@leockl


🐛 Bug

torch.argmax() returns different indices from numpy.argmax() when the maximum value occurs more than once along the reduced dimension (i.e. when there is a tie).

To Reproduce

Code:

import torch
import numpy as np

arr1 = np.array([1,2,3])
arr2 = np.array([5,10,3])
my_list_arr = [arr1, arr2]
A = np.transpose(my_list_arr)
A

Output:

array([[ 1,  5],
       [ 2, 10],
       [ 3,  3]])

Code:

tsr1 = torch.tensor(np.array([1,2,3]))
tsr2 = torch.tensor(np.array([5,10,3]))
my_list_tsr = [tsr1, tsr2]
B = torch.stack(my_list_tsr, dim = 1)
B

Output:

tensor([[ 1,  5],
        [ 2, 10],
        [ 3,  3]])

Code:

np.argmax(A, axis = 1)

Output:

array([1, 1, 0])

Code:

torch.argmax(B, dim = 1)

Output:

tensor([1, 1, 1])

Expected behavior

numpy.argmax() and torch.argmax() should return the same indices, but they differ for the third row, [3, 3], where the maximum value occurs twice: NumPy returns 0 and PyTorch returns 1.
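
The only difference is how ties are broken. NumPy's documentation states that when the maximum occurs multiple times, the index of the first occurrence is returned. The minimal NumPy sketch below contrasts first-occurrence and last-occurrence tie-breaking. The last-occurrence variant happens to match the tensor([1, 1, 1]) reported for PyTorch 1.5.1 on this input; that is an observation about this example only, not a claim about how PyTorch's kernel picks among ties in general.

```python
import numpy as np

# Same 3x2 matrix as A in the reproduction; the third row is a tie (3, 3).
A = np.array([[1, 5],
              [2, 10],
              [3, 3]])

# NumPy resolves ties to the FIRST occurrence of the maximum.
first = np.argmax(A, axis=1)

# Reversing the columns and mapping the index back gives the LAST occurrence.
n_cols = A.shape[1]
last = n_cols - 1 - np.argmax(A[:, ::-1], axis=1)

print(first)  # [1 1 0]
print(last)   # [1 1 1]
```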

Environment

PyTorch version: 1.5.1+cu101
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1+cu101
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.6.1+cu101
[conda] Could not collect

Additional context

For consistency, it would be helpful if torch.argmax() returned the same indices as numpy.argmax() when there are ties, since numpy.argmax() is the more commonly used function.
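
Until the two behave the same, one way to get NumPy-style first-occurrence indices without relying on any library's argmax tie-breaking is to compute the maximum along the axis, mark every position equal to it, and take the smallest marked index. The sketch below is written in NumPy so it runs standalone; the helper name `first_argmax` is hypothetical and it assumes NaN-free input. The same steps should map onto PyTorch (`Tensor.max(dim, keepdim=True)`, `torch.arange`, `torch.where`, `Tensor.min(dim)`), but that port is an untested assumption.

```python
import numpy as np

def first_argmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    """Hypothetical helper: index of the first maximum along `axis`,
    computed without using argmax's own tie-breaking. Assumes no NaNs."""
    # Maximum along `axis`, kept as a size-1 axis so it broadcasts against x.
    axis_max = x.max(axis=axis, keepdims=True)
    # True wherever an element equals the maximum (a tie gives several Trues).
    is_max = x == axis_max
    # Positions 0..n-1, shaped to broadcast along `axis`.
    n = x.shape[axis]
    shape = [1] * x.ndim
    shape[axis] = n
    positions = np.arange(n).reshape(shape)
    # Non-maximal positions become n (larger than any valid index),
    # so the minimum along `axis` is the first maximal position.
    candidates = np.where(is_max, positions, n)
    return candidates.min(axis=axis)

A = np.array([[1, 5], [2, 10], [3, 3]])
print(first_argmax(A))  # [1 1 0]
```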
