From 161bfa5c9f7a998946797dc33f97b2cf9a31bcb3 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Thu, 7 Aug 2025 15:07:00 -0700 Subject: [PATCH 1/2] Add stable Tensor get_device_index, use more stable DeviceIndex [ghstack-poisoned] --- torch/csrc/stable/tensor.h | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index 741da7e62e40..c3b7d11ccde0 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -1,13 +1,15 @@ #pragma once #include +#include #include +#include #include - namespace torch::stable { -using DeviceIndex = - int8_t; // this is from c10/core/Device.h and can be header only +// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we +// can converge on in this world as DeviceIndex in libtorch is not stable. +using DeviceIndex = int32_t; // The torch::stable::Tensor class is a highlevel C++ wrapper around // the C shim Tensor APIs. We've modeled this class after TensorBase, as custom @@ -95,7 +97,26 @@ class Tensor { return stride; } - DeviceIndex get_device() const { + // This is almost the same API as the one in TensorBase.h, except + // we add a check that the returned device_index is within the + // range of int8_t. + int8_t get_device() const { + int32_t device_index; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(ath_.get(), &device_index)); + STD_TORCH_CHECK( + device_index >= std::numeric_limits::min() && + device_index <= std::numeric_limits::max(), + "The returned device index is out of range of return type int8_t, please use get_device_index() instead."); + return static_cast(device_index); + } + + // The same as get_device but with two differences: + // 1. it has a more suiting name + // 2. it returns a DeviceIndex, which is int32_t in this world + // that should be more stable than the likely shifting + // DeviceIndex in libtorch (it is int8_t that might become int16_t) + DeviceIndex get_device_index() const { int32_t device_index; TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_index(ath_.get(), &device_index)); From 47e5af8f1793770c0864a476b4c318f517fe6f44 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Thu, 7 Aug 2025 15:08:07 -0700 Subject: [PATCH 2/2] Update on "Add stable Tensor get_device_index, use more stable DeviceIndex" [ghstack-poisoned] --- torch/csrc/stable/tensor.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index c3b7d11ccde0..b135aed3fc9b 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -107,7 +107,7 @@ class Tensor { STD_TORCH_CHECK( device_index >= std::numeric_limits::min() && device_index <= std::numeric_limits::max(), - "The returned device index is out of range of return type int8_t, please use get_device_index() instead."); + "Device index is out of range of return type int8_t, please use get_device_index() instead."); return static_cast(device_index); } @@ -120,7 +120,7 @@ class Tensor { int32_t device_index; TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_index(ath_.get(), &device_index)); - return static_cast(device_index); + return device_index; } bool is_cuda() const {