Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Tensor sgd_out_of_place(
const bool maximize) {
STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");

// these test the get_device() and get_device_index() methods
// while ascertaining that we are still on CPU
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");

int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
Expand Down
31 changes: 26 additions & 5 deletions torch/csrc/stable/tensor.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <limits>
#include <memory>

namespace torch::stable {

// DeviceIndex here is deliberately wider than the int8_t DeviceIndex in
// c10/core/Device.h: the libtorch DeviceIndex is not stable, so int32_t is the
// type we can converge on in this world.
using DeviceIndex = int32_t;

// The torch::stable::Tensor class is a highlevel C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
Expand Down Expand Up @@ -95,11 +97,30 @@ class Tensor {
return stride;
}

DeviceIndex get_device() const {
// This is almost the same API as the one in TensorBase.h, except
// we add a check that the returned device_index is within the
// range of int8_t.
int8_t get_device() const {
int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(ath_.get(), &device_index));
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int8_t>::min() &&
device_index <= std::numeric_limits<int8_t>::max(),
"Device index is out of range of return type int8_t, please use get_device_index() instead.");
return static_cast<int8_t>(device_index);
}

// The same query as get_device(), with two differences:
//   1. it has a more fitting name, and
//   2. it returns a DeviceIndex, which is int32_t in this world. That should
//      be more stable than the likely shifting DeviceIndex in libtorch (it is
//      int8_t that might become int16_t).
// No range check is performed here: the value written by the shim is returned
// as a DeviceIndex.
DeviceIndex get_device_index() const {
  // Out-parameter filled in by the shim call; the shim reports int32_t.
  int32_t device_index = 0;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(ath_.get(), &device_index));
  // Single return path; the explicit cast keeps the conversion to the alias
  // visible at the call site.
  return static_cast<DeviceIndex>(device_index);
}

bool is_cuda() const {
Expand Down
Loading