Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Refactor some libtorch c shim interfaces #109834

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def generate_kernel_call(
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(f"void *{var_name}{self.ending}")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(&{var_name}, {arg}));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, &{var_name}));"
)
dtype = V.graph.get_dtype(arg)
cpp_dtype = DTYPE_TO_CPP[dtype]
Expand Down Expand Up @@ -1215,7 +1215,7 @@ def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
if config.aot_inductor.abi_compatible:
code.writeline(f"int64_t* {name}_size;")
code.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(&{name}_size, {name}));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
)
else:
super().codegen_input_size_var_decl(code, name)
Expand All @@ -1224,7 +1224,7 @@ def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
if config.aot_inductor.abi_compatible:
code.writeline(f"int64_t* {name}_stride;")
code.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(&{name}_stride, {name}));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
)
else:
super().codegen_input_stride_var_decl(code, name)
Expand Down Expand Up @@ -1587,13 +1587,13 @@ def make_buffer_allocation(self, buffer):
if config.aot_inductor.abi_compatible:
device_type, device_id = device.split(",")
args = [
f"&{name}_handle",
str(len(buffer.get_size())),
self.codegen_int_array_var(size, self.wrapper_call),
self.codegen_int_array_var(stride, self.wrapper_call),
dtype,
device_type,
device_id,
f"&{name}_handle",
]
self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
self.wrapper_call.writeline(
Expand All @@ -1619,12 +1619,12 @@ def codegen_reinterpret_view(self, name, size, stride, offset, writer) -> str:
if writer is None:
writer = self
args = [
f"&{tmp_name}",
f"{name}",
dim,
self.codegen_int_array_var(size, writer),
self.codegen_int_array_var(stride, writer),
offset,
f"&{tmp_name}",
]
writer.writeline(f"AtenTensorHandle {tmp_name};")
writer.writeline(
Expand Down Expand Up @@ -2025,7 +2025,7 @@ def generate_args_decl(self, call_args):
if config.aot_inductor.abi_compatible:
self.writeline(f"CUdeviceptr {var_name};")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(reinterpret_cast<void**>(&{var_name}), {arg}));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast<void**>(&{var_name})));"
)
else:
self.writeline(
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/inductor/aot_runtime/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ class AOTICudaStreamGuard {
AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) {
CUDAStreamGuardHandle ptr;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_create_cuda_stream_guard(&ptr, stream, device_index));
aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr));
guard_ =
std::unique_ptr<void, std::function<void(void*)>>(ptr, [](void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard(
Expand Down
5 changes: 2 additions & 3 deletions torch/csrc/inductor/aot_runtime/model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,15 @@ class AOTInductorModelContainer {

AtenTensorHandle tensor_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
&tensor_handle,
internal_ptr,
ndim,
size,
stride,
offset,
dtype,
device_type,
0 // device index, should read it from cudaStream_t?
));
0, // device index, should read it from cudaStream_t?
&tensor_handle));
constants_->emplace(
std::move(name), std::move(RAIIAtenTensorHandle(tensor_handle)));
}
Expand Down
35 changes: 21 additions & 14 deletions torch/csrc/inductor/aoti_torch/c/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,52 +99,58 @@ aoti_torch_delete_tensor_object(AtenTensorHandle tensor);

// Get a pointer to the underlying storage data
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr(
void** ret, // returns borrowed reference
AtenTensorHandle tensor);
AtenTensorHandle tensor,
void** ret_data_ptr // returns borrowed reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_sizes(
int64_t** ret, // returns borrowed reference
AtenTensorHandle tensor);
AtenTensorHandle tensor,
int64_t** ret_size // returns borrowed reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_strides(
int64_t** ret, // returns borrowed reference
AtenTensorHandle tensor);
AtenTensorHandle tensor,
int64_t** ret_strides // returns borrowed reference
);

// This function will create a new tensor object and its pointer is returned
// through *ret_new_tensor. The caller is responsible for wrapping the tensor
// pointer with RAIIAtenTensorHandle, which will call
// aoti_torch_delete_tensor_object when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
AtenTensorHandle* ret, // returns new reference
AtenTensorHandle self,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset);
int64_t storage_offset,
AtenTensorHandle* ret_new_tensor // returns new reference
);

// This function will create a new tensor object and its pointer is returned
// through *ret_new_tensor. The caller is responsible for wrapping the tensor
// pointer with RAIIAtenTensorHandle, which will call
// aoti_torch_delete_tensor_object when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided(
AtenTensorHandle* ret, // returns new reference
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index);
int32_t device_index,
AtenTensorHandle* ret_new_tensor // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret, // returns new reference
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index);
int32_t device_index,
AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);
Expand Down Expand Up @@ -173,9 +179,10 @@ struct CUDAStreamGuardOpaque;
using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_stream_guard(
CUDAStreamGuardHandle* ret_guard, // returns new reference
void* stream,
int32_t device_index);
int32_t device_index,
CUDAStreamGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);
Expand Down
42 changes: 24 additions & 18 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,88 +91,94 @@ AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) {
});
}

AOTITorchError aoti_torch_get_data_ptr(void** ret, AtenTensorHandle tensor) {
AOTITorchError aoti_torch_get_data_ptr(
AtenTensorHandle tensor,
void** ret_data_ptr) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
*ret = t->data_ptr();
*ret_data_ptr = t->data_ptr();
});
}

AOTITorchError aoti_torch_get_sizes(int64_t** ret, AtenTensorHandle tensor) {
AOTITorchError aoti_torch_get_sizes(
AtenTensorHandle tensor,
int64_t** ret_sizes) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
*ret = const_cast<int64_t*>(t->sizes().data());
*ret_sizes = const_cast<int64_t*>(t->sizes().data());
});
}

AOTITorchError aoti_torch_get_strides(int64_t** ret, AtenTensorHandle tensor) {
AOTITorchError aoti_torch_get_strides(
AtenTensorHandle tensor,
int64_t** ret_strides) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
*ret = const_cast<int64_t*>(t->strides().data());
*ret_strides = const_cast<int64_t*>(t->strides().data());
});
}

AOTITorchError aoti_torch__reinterpret_tensor(
AtenTensorHandle* ret,
AtenTensorHandle self,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t offset_increment) {
int64_t offset_increment,
AtenTensorHandle* ret_new_tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
c10::IntArrayRef sizes(sizes_ptr, ndim);
c10::IntArrayRef strides(strides_ptr, ndim);
at::Tensor* out_tensor =
at::Tensor* new_tensor =
new at::Tensor(torch::inductor::_reinterpret_tensor(
*self_tensor, sizes, strides, offset_increment));
*ret = tensor_pointer_to_tensor_handle(out_tensor);
*ret_new_tensor = tensor_pointer_to_tensor_handle(new_tensor);
});
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_empty_strided(
AtenTensorHandle* ret,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int32_t dtype,
int32_t device_type,
int32_t device_index) {
int32_t device_index,
AtenTensorHandle* ret_new_tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
c10::IntArrayRef sizes(sizes_ptr, ndim);
c10::IntArrayRef strides(strides_ptr, ndim);
c10::Device device = c10_device(device_type, device_index);
c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
static_cast<c10::ScalarType>(dtype));
at::Tensor* out_tensor =
at::Tensor* new_tensor =
new at::Tensor(at::empty_strided(sizes, strides, options));
*ret = tensor_pointer_to_tensor_handle(out_tensor);
*ret_new_tensor = tensor_pointer_to_tensor_handle(new_tensor);
});
}

AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret,
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index) {
int32_t device_index,
AtenTensorHandle* ret_new_tensor) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
c10::IntArrayRef sizes(sizes_ptr, ndim);
c10::IntArrayRef strides(strides_ptr, ndim);
c10::Device device = c10_device(device_type, device_index);
c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
static_cast<c10::ScalarType>(dtype));
at::Tensor* out_tensor = new at::Tensor(at::for_blob(data, sizes)
at::Tensor* new_tensor = new at::Tensor(at::for_blob(data, sizes)
.strides(strides)
.storage_offset(storage_offset)
.options(options)
.make_tensor());
*ret = tensor_pointer_to_tensor_handle(out_tensor);
*ret_new_tensor = tensor_pointer_to_tensor_handle(new_tensor);
});
}

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/inductor/aoti_torch/shim_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
#include <c10/cuda/CUDAStream.h>

AOTITorchError aoti_torch_create_cuda_stream_guard(
CUDAStreamGuardHandle* ret_guard,
void* stream,
int32_t device_index) {
int32_t device_index,
CUDAStreamGuardHandle* ret_guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::cuda::CUDAStreamGuard* guard =
new at::cuda::CUDAStreamGuard(at::cuda::getStreamFromExternal(
Expand Down
Loading