@@ -1,12 +1,9 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

#include <optional>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

void inline sgd_math(
float* param_ptr,
float* grad_ptr,
@@ -27,17 +24,14 @@ void inline sgd_math(
}
}

using torch::stable::Tensor;

RAIIATH sgd_out_of_place(
const RAIIATH param,
const RAIIATH grad,
Tensor sgd_out_of_place(
const Tensor param,
const Tensor grad,
const float weight_decay,
const double lr,
const bool maximize) {

int64_t param_dim;
aoti_torch_get_dim(param.get(), &param_dim);

int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
@@ -47,56 +41,34 @@ RAIIATH sgd_out_of_place(
aoti_torch_get_dtype(param.get(), &param_dtype);

int32_t param_device_type;
int32_t param_device_index;
aoti_torch_get_device_type(param.get(), &param_device_type);
aoti_torch_get_device_index(param.get(), &param_device_index);

AtenTensorHandle out;
aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);

void* param_ptr;
aoti_torch_get_data_ptr(param.get(), &param_ptr);
void* grad_ptr;
aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
void* out_ptr;
aoti_torch_get_data_ptr(out, &out_ptr);

auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);

int64_t param_numel;
aoti_torch_get_numel(param.get(), &param_numel);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);

sgd_math(
param_fp_ptr,
grad_fp_ptr,
out_fp_ptr,
reinterpret_cast<float*>(param.data_ptr()),
reinterpret_cast<float*>(grad.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
Comment on lines +51 to +53

Collaborator: Side thought: I feel like templated data_ptr() accessors (with a dtype check) would be nice to help users write safer code (by not casting a float16 to float32 by mistake) and get const correctness (for the right COW behavior).

Contributor (Author): As in, an API to add to stable::Tensor, right?

weight_decay,
lr,
maximize,
param_numel
param.numel()
);

return RAIIATH(out);
return out;
}
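Regarding the review note above about templated data_ptr() accessors: a minimal sketch of what such a dtype-checked accessor could look like, written as a free helper on top of the shim calls already used in this file. This is hypothetical and not part of this PR; the name checked_data_ptr and the error handling via std::runtime_error are assumptions.

#include <stdexcept>

// Hypothetical helper (not in this PR): returns a typed pointer only if the
// tensor's runtime dtype matches the expected one, so e.g. a float16 tensor
// cannot be silently reinterpreted as float32.
template <typename T>
T* checked_data_ptr(const Tensor& t, int32_t expected_dtype) {
  int32_t dtype;
  aoti_torch_get_dtype(t.get(), &dtype);
  if (dtype != expected_dtype) {
    throw std::runtime_error("checked_data_ptr: unexpected dtype");
  }
  return reinterpret_cast<T*>(t.data_ptr());
}

// Example use in sgd_out_of_place, assuming float32 inputs:
//   float* param_fp_ptr = checked_data_ptr<float>(param, aoti_torch_dtype_float32());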


void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
RAIIATH param(to<AtenTensorHandle>(stack[0]));
RAIIATH grad(to<AtenTensorHandle>(stack[1]));
auto weight_decay = to<double>(stack[2]);
auto lr = to<double>(stack[3]);
auto maximize = to<bool>(stack[4]);

RAIIATH raiiath_res = sgd_out_of_place(
std::move(param),
std::move(grad),
float(weight_decay),
lr,
maximize);
Tensor res = sgd_out_of_place(
to<Tensor>(stack[0]),
to<Tensor>(stack[1]),
float(to<double>(stack[2])),
Collaborator: Why "float()" here?

Contributor (Author): It doesn't have to be; I just defined it this way to show how you'd use float if your kernel takes float.

to<double>(stack[3]),
to<bool>(stack[4]));

stack[0] = from(raiiath_res.release());
stack[0] = from(res);
}

STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
@@ -107,14 +79,13 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
}

RAIIATH identity(RAIIATH t) {
return std::move(t);
Tensor identity(Tensor t) {
return t;
}

void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
RAIIATH t(to<AtenTensorHandle>(stack[0]));
RAIIATH raiiath_res = identity(std::move(t));
stack[0] = from(raiiath_res.release());
Tensor res = identity(to<Tensor>(stack[0]));
stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -129,8 +100,6 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("identity", &boxed_identity);
}

using torch::stable::Tensor;

Tensor my_abs(Tensor t) {
const auto num_args = 1;
StableIValue stack[num_args];
@@ -152,15 +121,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_abs", &boxed_my_abs);
}

RAIIATH my_ones_like(RAIIATH t, StableIValue device) {
Tensor my_ones_like(Tensor t, StableIValue device) {
const auto num_args = 6;
StableIValue stack[num_args];

int32_t t_dtype;
aoti_torch_get_dtype(t.get(), &t_dtype);
auto mf = aoti_torch_memory_format_contiguous_format();

stack[0] = from(t.release());
stack[0] = from(t);
stack[1] = from(std::optional(t_dtype)); // dtype
stack[2] = from(std::nullopt); // layout
stack[3] = from(std::optional(device)); // device
@@ -169,15 +138,12 @@ RAIIATH my_ones_like(RAIIATH t, StableIValue device) {

aoti_torch_call_dispatcher("aten::ones_like", "", stack);

return RAIIATH(to<AtenTensorHandle>(stack[0]));
return to<Tensor>(stack[0]);
}

void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
RAIIATH t(to<AtenTensorHandle>(stack[0]));
StableIValue device = stack[1];

RAIIATH raiiath_res = my_ones_like(std::move(t), device);
stack[0] = from(raiiath_res.release());
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -188,32 +154,29 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_ones_like", &boxed_my_ones_like);
}

std::tuple<RAIIATH, RAIIATH, bool> exp_neg_is_leaf(RAIIATH t1, RAIIATH t2, RAIIATH t3) {
StableIValue stack1[1];
stack1[0] = from(t1.release());
aoti_torch_call_dispatcher("aten::exp", "", stack1);
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
StableIValue stack_exp[1];
stack_exp[0] = from(t1);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);

StableIValue stack2[1];
stack2[0] = from(t2.release());
aoti_torch_call_dispatcher("aten::neg", "", stack2);
StableIValue stack_neg[1];
stack_neg[0] = from(t2);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);

StableIValue stack3[1];
stack3[0] = from(t3.release());
aoti_torch_call_dispatcher("aten::is_leaf", "", stack3);
StableIValue stack_is_leaf[1];
stack_is_leaf[0] = from(t3);
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);

return std::make_tuple(
RAIIATH(to<AtenTensorHandle>(stack1[0])),
RAIIATH(to<AtenTensorHandle>(stack2[0])),
to<bool>(stack3[0]));
to<Tensor>(stack_exp[0]),
to<Tensor>(stack_neg[0]),
to<bool>(stack_is_leaf[0]));
}

void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
RAIIATH t1(to<AtenTensorHandle>(stack[0]));
RAIIATH t2(to<AtenTensorHandle>(stack[1]));
RAIIATH t3(to<AtenTensorHandle>(stack[2]));
auto tuple = exp_neg_is_leaf(std::move(t1), std::move(t2), std::move(t3));
stack[0] = from(std::get<0>(tuple).release());
stack[1] = from(std::get<1>(tuple).release());
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
stack[0] = from(std::get<0>(tuple));
stack[1] = from(std::get<1>(tuple));
stack[2] = from(std::get<2>(tuple));
}

torch/csrc/stable/tensor.h: 12 additions & 0 deletions
@@ -69,6 +69,18 @@ class Tensor {
return data_ptr;
}

int64_t dim() const {
Collaborator: Why are these signed instead of unsigned anyway...

Collaborator: The same answer as for many of these things in PyTorch: historical reasons :D

I would add that, in this case, since every function that takes a "dim" argument must accept signed values (to handle dim = -1, etc.), it is more convenient to use the same type throughout for everything that represents a "dim" and avoid casting from unsigned to signed everywhere.

int64_t dim;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
return dim;
}

int64_t numel() const {
int64_t numel;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
return numel;
}

int64_t stride(int64_t dim) const {
int64_t stride;
AOTI_TORCH_ERROR_CODE_CHECK(
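Following up on the signed-vs-unsigned review discussion on dim() above: a small standalone illustration (not from this PR; the function name is made up) of why dim indices stay signed throughout.

#include <cstdint>

// Hypothetical illustration: APIs that take a dim accept negative values that
// count from the end (e.g. dim = -1 is the last dimension), so the index type
// has to be signed; keeping int64_t everywhere avoids signed/unsigned casts.
inline int64_t wrap_dim_example(int64_t dim, int64_t ndim) {
  return dim < 0 ? dim + ndim : dim;  // wrap_dim_example(-1, 4) == 3
}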