index.Tensor_out doesn't properly handle dtype conversions #107698

Closed
manuelcandales opened this issue Aug 22, 2023 · 3 comments
Labels
high priority, module: advanced indexing, module: correctness (silent), module: error checking, triaged

Comments

@manuelcandales
Contributor

manuelcandales commented Aug 22, 2023

🐛 Describe the bug

import torch

input = torch.tensor([[1, 2], [3, 4]], dtype=torch.int)
indices = [torch.tensor([[True, False], [False, True]], dtype=torch.bool)]

# Both calls should return tensor([1, 4]), but index.Tensor_out doesn't seem to handle the dtype conversion to out properly
print(torch.ops.aten.index.Tensor_out(input, indices, out=torch.tensor([0, 0], dtype=torch.int)))
print(torch.ops.aten.index.Tensor_out(input, indices, out=torch.tensor([0, 0], dtype=torch.long)))

Result:

tensor([1, 4], dtype=torch.int32)
tensor([8589934593,          4])
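The first value of the second result looks like garbage, but it is numerically consistent with two adjacent int32 values being read back as one int64: 8589934593 = 2 * 2^32 + 1. A minimal stdlib sketch, assuming a little-endian host (the reporter's Apple M1 is little-endian), packs the int32 pair (1, 2), i.e. the first row of input, and reinterprets those 8 bytes as a single int64. This only illustrates the arithmetic; it is not a claim about what the kernel does internally.

```python
import struct

# Two int32 values laid out back to back, little-endian: bytes 01 00 00 00 02 00 00 00
raw = struct.pack("<ii", 1, 2)

# Reinterpret the same 8 bytes as a single little-endian int64
(as_int64,) = struct.unpack("<q", raw)

print(as_int64)                   # 8589934593
print(as_int64 == 2 * 2**32 + 1)  # True
```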

Versions

Collecting environment information...
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.5 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.6 (default, Oct 18 2022, 12:41:40) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-13.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.0.1
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @malfet

@cpuhrsch cpuhrsch added the triaged, module: advanced indexing, and high priority labels Aug 22, 2023
@cpuhrsch
Contributor

Not sure what the right label is again

@malfet
Contributor

malfet commented Aug 28, 2023

The issue is not specific to ARM. IMO it just needs error checking, as we usually don't do type promotion in out variants. Grabbing this for myself; a fix is coming.
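If an int64 result really is wanted, one way to get it without relying on promotion in the out variant is to index in the input's own dtype and then convert explicitly. A minimal sketch, assuming standard boolean-mask indexing and `Tensor.copy_`, which casts to the destination tensor's dtype; the names `inp`, `mask`, and `selected` are illustrative only:

```python
import torch

inp = torch.tensor([[1, 2], [3, 4]], dtype=torch.int)
mask = torch.tensor([[True, False], [False, True]])

# Index in the input's own dtype (int32); a same-shape boolean mask yields a 1-D tensor
selected = inp[mask]

# Do the int32 -> int64 conversion explicitly instead of relying on the out= variant
out = torch.empty(selected.shape, dtype=torch.long)
out.copy_(selected)  # copy_ casts the values to out's dtype

print(out)  # tensor([1, 4])
```

Preallocating `out` and filling it with `copy_` keeps the "caller owns the buffer" shape of an out variant while making the cast visible at the call site.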

@malfet malfet self-assigned this Aug 28, 2023
@malfet malfet added the triage review and module: error checking labels and removed the triage review label Aug 28, 2023
@malfet
Contributor

malfet commented Aug 29, 2023

Surprisingly, though, we do support type promotion in, say, binary ops:

import torch
torch.add(torch.rand(3), torch.rand(3), out=torch.empty(3, dtype=torch.float16))
# tensor([0.8267, 0.2036, 1.7734], dtype=torch.float16)

But not in, say, index_put:

import torch
x = torch.arange(4, dtype=torch.int).reshape(2, 2)
x[torch.tensor([[True, False], [False, True]])] = torch.empty(2, dtype=torch.float)
# Raises: Index put requires the source and destination dtypes match
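On the index_put side, the mismatch reported in this comment can be avoided by casting the source to the destination dtype before assigning. A minimal sketch, assuming the same 2x2 int32 tensor; the source values are arbitrary and chosen to be exactly representable as integers:

```python
import torch

x = torch.arange(4, dtype=torch.int).reshape(2, 2)  # [[0, 1], [2, 3]]
mask = torch.tensor([[True, False], [False, True]])
src = torch.tensor([10.0, 20.0])  # float32 source

# Assigning src directly is what triggers the dtype-mismatch error above;
# casting it to x's dtype first makes source and destination agree
x[mask] = src.to(x.dtype)

print(x.tolist())  # [[10, 1], [2, 20]]
```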

malfet added a commit that referenced this issue Aug 29, 2023
This logic exists for index_put and index_add, but for some reason not for index_get

Fixes #107698
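The commit message describes adding to index the same dtype check that index_put and index_add already perform. A minimal Python sketch of that kind of check, written as a hypothetical wrapper `index_out_checked` (not a PyTorch API; the real check lives in ATen's C++ code and its exact error message may differ):

```python
import torch

def index_out_checked(inp, indices, out):
    """Hypothetical wrapper: reject an out tensor whose dtype differs from the
    input's, instead of letting the kernel silently write mismatched values."""
    if out.dtype != inp.dtype:
        raise RuntimeError(
            f"index_out: expected out dtype {inp.dtype}, got {out.dtype}"
        )
    return torch.ops.aten.index.Tensor_out(inp, indices, out=out)

inp = torch.tensor([[1, 2], [3, 4]], dtype=torch.int)
indices = [torch.tensor([[True, False], [False, True]])]

# Matching dtypes: same call as in the report, returns the expected values
print(index_out_checked(inp, indices, out=torch.zeros(2, dtype=torch.int)))

# Mismatched dtypes: rejected up front rather than producing a wrong result
try:
    index_out_checked(inp, indices, out=torch.zeros(2, dtype=torch.long))
except RuntimeError as err:
    print("rejected:", err)
```

Raising instead of converting matches the position taken earlier in the thread that out variants usually don't do type promotion.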
@malfet malfet added the module: correctness (silent) label Aug 29, 2023
facebook-github-bot pushed a commit to pytorch/executorch that referenced this issue Sep 14, 2023
Summary:
Pull Request resolved: #321

Previously reported index [issue](pytorch/pytorch#107698) related to input/out type conversion was fixed in ATen by requiring the out tensor to have the same dtype as the input tensor.
This diff updates the portable index kernel to stay dtype-compliant.

Reviewed By: larryliu0820

Differential Revision: D49218705

fbshipit-source-id: 0aa16a8ffd666c69b9386bd6b2a8c1e23398872c

4 participants