Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,6 @@ aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle) {
(void)orig_handle;
(void)new_handle;
throw std::runtime_error("Not implemented");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
void* data_ptr,
int64_t ndim,
Expand Down
3 changes: 0 additions & 3 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor);

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);

AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
void* data_ptr,
int64_t ndim,
Expand Down
9 changes: 9 additions & 0 deletions backends/apple/metal/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,15 @@ AOTITorchError aoti_torch__reinterpret_tensor(
return Error::Ok;
}

// Stub: creating a new tensor handle is not yet implemented for the Metal
// backend. This always throws; callers must not rely on the return value.
AOTITorchError aoti_torch_new_tensor_handle(
    Tensor* orig_handle,
    Tensor** new_handle) {
  (void)orig_handle; // unused until implemented
  (void)new_handle; // unused until implemented
  throw std::runtime_error("Not implemented");
  return Error::Internal; // unreachable; silences missing-return warnings
}

// Cleanup function for clearing global state
void cleanup_memory() {
// Use aoti_torch_delete_tensor_object to properly delete each tensor
Expand Down
4 changes: 4 additions & 0 deletions backends/apple/metal/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ AOTITorchError aoti_torch__reinterpret_tensor(
int64_t storage_offset,
AOTITensorHandle* ret_new_tensor);

// Creates a new tensor handle that shares storage with `orig_handle`.
// NOTE(review): the Metal implementation currently throws "Not implemented"
// (see memory.cpp) — confirm callers guard against this before relying on it.
AOTITorchError aoti_torch_new_tensor_handle(
    Tensor* orig_handle,
    Tensor** new_handle);

void cleanup_memory();

} // extern "C"
Expand Down
89 changes: 89 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,95 @@ AOTITorchError aoti_torch__reinterpret_tensor(
return Error::Ok;
}

// Creates a new tensor handle that aliases the storage of `orig_handle`.
//
// The new tensor shares the same data pointer, sizes, strides and dtype as the
// original (no deep copy). If the underlying memory is owned by the tensor
// system, its reference count is incremented; externally owned (NOT_OWN)
// memory is never reference-counted.
//
// @param orig_handle Source tensor (must be non-null, tracked, with non-null
//                    data pointer).
// @param new_handle  Out-param receiving the new handle (must be non-null).
// @return Error::Ok on success, Error::InvalidArgument on bad input.
AOTITorchError aoti_torch_new_tensor_handle(
    Tensor* orig_handle,
    Tensor** new_handle) {
  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      orig_handle != nullptr,
      InvalidArgument,
      "aoti_torch_new_tensor_handle failed: orig_handle is null");

  ET_CHECK_OR_RETURN_ERROR(
      new_handle != nullptr,
      InvalidArgument,
      "aoti_torch_new_tensor_handle failed: new_handle is null");

  // Get metadata from the original tensor
  int64_t* sizes_ptr;
  int64_t* strides_ptr;
  int32_t dtype;
  int32_t device_type;
  int32_t device_index;

  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(
      aoti_torch_get_strides(orig_handle, &strides_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
  ET_CHECK_OK_OR_RETURN_ERROR(
      aoti_torch_get_device_type(orig_handle, &device_type));
  ET_CHECK_OK_OR_RETURN_ERROR(
      aoti_torch_get_device_index(orig_handle, &device_index));

  int64_t ndim = orig_handle->dim();

  // Validate dtype
  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));

  // Ensure device_index is always 0 (only a single CUDA device is supported)
  ET_CHECK_OR_RETURN_ERROR(
      device_index == 0,
      InvalidArgument,
      "device_index must be 0, got: %d",
      device_index);

  // Get the original data pointer from the source tensor
  void* data_ptr = orig_handle->mutable_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "Source tensor has null data pointer");

  // The memory must already be tracked by the reference counting system;
  // otherwise we cannot safely share it.
  ET_CHECK_OR_RETURN_ERROR(
      memory_to_n_tensor.find(data_ptr) != memory_to_n_tensor.end(),
      InvalidArgument,
      "Memory address %p is not being tracked by reference counting system",
      data_ptr);

  // Convert sizes and strides to vectors
  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
  std::vector<StridesType> strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Create new tensor that shares the same memory as the original.
  // This is similar to PyTorch's Tensor copy constructor - creates a new
  // tensor object that shares the same underlying storage.
  std::shared_ptr<Tensor> tensor = make_tensor(
      sizes, // Same sizes as original
      data_ptr, // Share the same memory from source tensor
      {}, // dim_order (empty, will be auto-generated)
      strides, // Same strides as original
      dtype_to_scalar_type(dtype) // Same dtype as original
  );

  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);

  *new_handle = tensor.get();

  // Increment the reference count for this memory address only if the memory
  // is owned by a tensor; NOT_OWN (external) memory is never counted.
  // Single lookup through a reference instead of three separate hash lookups.
  auto& ref_count = memory_to_n_tensor[data_ptr];
  if (ref_count != NOT_OWN) {
    ref_count += 1;
  }

  return Error::Ok;
}
} // extern "C"

} // namespace executorch::backends::cuda
25 changes: 25 additions & 0 deletions backends/cuda/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,31 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);

/**
* Creates a new tensor handle from an existing one.
*
* This function creates a new tensor object that shares the same underlying
* memory as the original tensor. Similar to PyTorch's Tensor copy constructor,
* it creates a new handle/reference to the same data without performing a deep
* copy.
*
* The new tensor will:
* - Share the same memory/storage as the original tensor
* - Have the same shape, strides, and dtype as the original
* - Increment the reference count for the underlying memory (if owned)
*
* @param orig_handle Original tensor to create a new handle from (must not be
* null)
* @param new_handle Output pointer to store the new tensor handle (must not be
* null)
*
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers or invalid parameters
*/
AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle);

// Function to clear all tensors from internal storage
AOTI_SHIM_EXPORT void clear_all_tensors();
} // extern "C"
Expand Down
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_copy_")
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
Loading
Loading