Replace all RAIIATH with Tensor in libtorch_agnostic test, test some APIs #155977
Changes from all commits: 72aba5c, 1be5a63, ba01acb, 4c80667
```diff
@@ -1,12 +1,9 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
-#include <torch/csrc/inductor/aoti_runtime/utils.h>
 #include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/tensor.h>

 #include <optional>

-using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
-
 void inline sgd_math(
     float* param_ptr,
     float* grad_ptr,
```
```diff
@@ -27,17 +24,14 @@ void inline sgd_math(
   }
 }

-RAIIATH sgd_out_of_place(
-    const RAIIATH param,
-    const RAIIATH grad,
+using torch::stable::Tensor;
+
+Tensor sgd_out_of_place(
+    const Tensor param,
+    const Tensor grad,
     const float weight_decay,
     const double lr,
     const bool maximize) {
-  int64_t param_dim;
-  aoti_torch_get_dim(param.get(), &param_dim);
-
   int64_t *param_sizes;
   int64_t *param_strides;
   aoti_torch_get_sizes(param.get(), &param_sizes);
```
```diff
@@ -47,56 +41,34 @@ RAIIATH sgd_out_of_place(
   aoti_torch_get_dtype(param.get(), &param_dtype);

   int32_t param_device_type;
-  int32_t param_device_index;
   aoti_torch_get_device_type(param.get(), &param_device_type);
-  aoti_torch_get_device_index(param.get(), &param_device_index);
-
-  AtenTensorHandle out;
-  aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);
-
-  void* param_ptr;
-  aoti_torch_get_data_ptr(param.get(), &param_ptr);
-  void* grad_ptr;
-  aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
-  void* out_ptr;
-  aoti_torch_get_data_ptr(out, &out_ptr);
-
-  auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
-  auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
-  auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);
-
-  int64_t param_numel;
-  aoti_torch_get_numel(param.get(), &param_numel);
+  AtenTensorHandle out_ath;
+  aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
+  auto out = Tensor(out_ath);

   sgd_math(
-    param_fp_ptr,
-    grad_fp_ptr,
-    out_fp_ptr,
+    reinterpret_cast<float*>(param.data_ptr()),
+    reinterpret_cast<float*>(grad.data_ptr()),
+    reinterpret_cast<float*>(out.data_ptr()),
     weight_decay,
     lr,
     maximize,
-    param_numel
+    param.numel()
   );

-  return RAIIATH(out);
+  return out;
 }
```
```diff
 void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  RAIIATH param(to<AtenTensorHandle>(stack[0]));
-  RAIIATH grad(to<AtenTensorHandle>(stack[1]));
-  auto weight_decay = to<double>(stack[2]);
-  auto lr = to<double>(stack[3]);
-  auto maximize = to<bool>(stack[4]);
-
-  RAIIATH raiiath_res = sgd_out_of_place(
-    std::move(param),
-    std::move(grad),
-    float(weight_decay),
-    lr,
-    maximize);
+  Tensor res = sgd_out_of_place(
+    to<Tensor>(stack[0]),
+    to<Tensor>(stack[1]),
+    float(to<double>(stack[2])),
+    to<double>(stack[3]),
+    to<bool>(stack[4]));
```

Review comment (on `float(to<double>(stack[2]))`): Why "float()" here?

Reply: It doesn't have to be; I just defined it this way to show how someone would use float if their kernel takes a float.
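To make the reply concrete, here is a minimal sketch of that pattern with a hypothetical `scale_by` kernel (the kernel name and schema are made up for illustration; `StableIValue`, `to<>`, `from()`, and `Tensor` are the same stable-API pieces used in the diff above). The boxed calling convention carries the scalar as a double, and the wrapper narrows it only because the unboxed kernel happens to take a float.

```cpp
#include <torch/csrc/stable/library.h>  // StableIValue, to<>, from(), as in the file above
#include <torch/csrc/stable/tensor.h>   // torch::stable::Tensor

using torch::stable::Tensor;

// Hypothetical unboxed kernel that takes a float scalar (illustration only).
Tensor scale_by(Tensor t, float factor);

// Boxed wrapper: the stack stores the scalar as a double, so we unpack a
// double and narrow it, mirroring float(to<double>(stack[2])) in the PR.
void boxed_scale_by(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor res = scale_by(
      to<Tensor>(stack[0]),
      static_cast<float>(to<double>(stack[1])));  // float(...) would do the same
  stack[0] = from(res);
}
```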
```diff
-  stack[0] = from(raiiath_res.release());
+  stack[0] = from(res);
 }

 STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
```
```diff
@@ -107,14 +79,13 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
 }

-RAIIATH identity(RAIIATH t) {
-  return std::move(t);
+Tensor identity(Tensor t) {
+  return t;
 }

 void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  RAIIATH t(to<AtenTensorHandle>(stack[0]));
-  RAIIATH raiiath_res = identity(std::move(t));
-  stack[0] = from(raiiath_res.release());
+  Tensor res = identity(to<Tensor>(stack[0]));
+  stack[0] = from(res);
 }

 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
```
```diff
@@ -129,8 +100,6 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
   m.impl("identity", &boxed_identity);
 }

-using torch::stable::Tensor;
-
 Tensor my_abs(Tensor t) {
   const auto num_args = 1;
   StableIValue stack[num_args];
```
```diff
@@ -152,15 +121,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_abs", &boxed_my_abs);
 }

-RAIIATH my_ones_like(RAIIATH t, StableIValue device) {
+Tensor my_ones_like(Tensor t, StableIValue device) {
   const auto num_args = 6;
   StableIValue stack[num_args];

   int32_t t_dtype;
   aoti_torch_get_dtype(t.get(), &t_dtype);
   auto mf = aoti_torch_memory_format_contiguous_format();

-  stack[0] = from(t.release());
+  stack[0] = from(t);
   stack[1] = from(std::optional(t_dtype)); // dtype
   stack[2] = from(std::nullopt); // layout
   stack[3] = from(std::optional(device)); // device
```
```diff
@@ -169,15 +138,12 @@ RAIIATH my_ones_like(RAIIATH t, StableIValue device) {
   aoti_torch_call_dispatcher("aten::ones_like", "", stack);

-  return RAIIATH(to<AtenTensorHandle>(stack[0]));
+  return to<Tensor>(stack[0]);
 }

 void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  RAIIATH t(to<AtenTensorHandle>(stack[0]));
-  StableIValue device = stack[1];
-
-  RAIIATH raiiath_res = my_ones_like(std::move(t), device);
-  stack[0] = from(raiiath_res.release());
+  Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
+  stack[0] = from(res);
 }

 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
```
```diff
@@ -188,32 +154,29 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_ones_like", &boxed_my_ones_like);
 }

-std::tuple<RAIIATH, RAIIATH, bool> exp_neg_is_leaf(RAIIATH t1, RAIIATH t2, RAIIATH t3) {
-  StableIValue stack1[1];
-  stack1[0] = from(t1.release());
-  aoti_torch_call_dispatcher("aten::exp", "", stack1);
+std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
+  StableIValue stack_exp[1];
+  stack_exp[0] = from(t1);
+  aoti_torch_call_dispatcher("aten::exp", "", stack_exp);

-  StableIValue stack2[1];
-  stack2[0] = from(t2.release());
-  aoti_torch_call_dispatcher("aten::neg", "", stack2);
+  StableIValue stack_neg[1];
+  stack_neg[0] = from(t2);
+  aoti_torch_call_dispatcher("aten::neg", "", stack_neg);

-  StableIValue stack3[1];
-  stack3[0] = from(t3.release());
-  aoti_torch_call_dispatcher("aten::is_leaf", "", stack3);
+  StableIValue stack_is_leaf[1];
+  stack_is_leaf[0] = from(t3);
+  aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);

   return std::make_tuple(
-    RAIIATH(to<AtenTensorHandle>(stack1[0])),
-    RAIIATH(to<AtenTensorHandle>(stack2[0])),
-    to<bool>(stack3[0]));
+    to<Tensor>(stack_exp[0]),
+    to<Tensor>(stack_neg[0]),
+    to<bool>(stack_is_leaf[0]));
 }

 void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  RAIIATH t1(to<AtenTensorHandle>(stack[0]));
-  RAIIATH t2(to<AtenTensorHandle>(stack[1]));
-  RAIIATH t3(to<AtenTensorHandle>(stack[2]));
-  auto tuple = exp_neg_is_leaf(std::move(t1), std::move(t2), std::move(t3));
-  stack[0] = from(std::get<0>(tuple).release());
-  stack[1] = from(std::get<1>(tuple).release());
+  auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
+  stack[0] = from(std::get<0>(tuple));
+  stack[1] = from(std::get<1>(tuple));
   stack[2] = from(std::get<2>(tuple));
 }
```
```diff
@@ -69,6 +69,18 @@ class Tensor {
     return data_ptr;
   }

+  int64_t dim() const {
+    int64_t dim;
+    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
+    return dim;
+  }
```
Review comment (on `int64_t dim() const`): Why are these signed instead of unsigned anyway...

Reply: The same answer as for many of these things in PyTorch: historical reasons :D In this case I would also note that, since all the functions that take a "dim" argument must be signed anyway (to handle dim = -1, etc.), it is more convenient to use the same type throughout for everything that represents a "dim" and avoid casting from unsigned to signed everywhere.
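As a concrete illustration of that convenience (a hypothetical standalone helper, not part of this PR or the stable API): with a signed dim, callers can pass -1 for the last dimension, and normalization is a single wrap with no signed/unsigned casts.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical normalization helper, shown only to illustrate why "dim" stays
// signed: -1 means the last dimension, -2 the one before it, and so on.
inline int64_t normalize_dim(int64_t dim, int64_t ndim) {
  if (dim < -ndim || dim >= ndim) {
    throw std::out_of_range("dim out of range");
  }
  return dim < 0 ? dim + ndim : dim;  // wrap negative dims once, no casts needed
}

// e.g. normalize_dim(-1, 4) == 3, so a stride(-1) call could resolve to stride(3)
```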
```diff
+
+  int64_t numel() const {
+    int64_t numel;
+    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
+    return numel;
+  }
+
   int64_t stride(int64_t dim) const {
     int64_t stride;
     AOTI_TORCH_ERROR_CODE_CHECK(
```
Review comment: Side thought: I feel like templated data_ptr() accessors (with a dtype check) would be nice to help users write safer code (by not casting a float16 to float32 by mistake) and to get const correctness (the right COW behavior).

Reply: As in, an API to add to stable::Tensor, right?
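For the sake of discussion, a minimal sketch of what such an accessor could look like as a member of torch::stable::Tensor, reusing the `ath_` handle and the shim calls that already appear above. The member names, the per-type dtype mapping (only float is shown), and the const variant are all hypothetical; none of this is added by the PR, and it assumes `<type_traits>` and `<stdexcept>` are available.

```cpp
  // Hypothetical dtype-checked accessor (sketch only, not in this PR): returns
  // a typed pointer and refuses to reinterpret the buffer as the wrong dtype.
  template <typename T>
  T* data_ptr() {
    int32_t dtype;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(ath_.get(), &dtype));
    // Illustration: only float is mapped; a real version would cover every
    // supported T (float16, bfloat16, int64_t, ...).
    if (std::is_same_v<T, float> && dtype != aoti_torch_dtype_float32()) {
      throw std::runtime_error("data_ptr<T>(): tensor dtype does not match T");
    }
    void* ptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &ptr));
    return static_cast<T*>(ptr);
  }

  // A const counterpart returning const T* is what would give the const
  // correctness (and copy-on-write friendliness) mentioned in the comment.
  template <typename T>
  const T* const_data_ptr() const {
    return const_cast<Tensor*>(this)->data_ptr<T>();
  }
```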