Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions examples/llm_manual/managed_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,21 @@ class ManagedTensor {
using DimOrderType = exec_aten::DimOrderType;
/// The type used for elements of `strides()`.
using StridesType = exec_aten::StridesType;

ManagedTensor() = delete;

explicit ManagedTensor(
void* data,
const std::vector<SizesType>& sizes,
ScalarType dtype)
: dtype_(dtype), sizes_(sizes), data_ptr_(data) {
ssize_t dim = sizes.size();
dim_order_.resize(dim);
strides_.resize(dim);
for (size_t i = 0; i < dim; ++i) {
dim_order_[i] = i;
}
dim_order_to_stride_nocheck(
sizes.data(), dim_order_.data(), dim, strides_.data());
: sizes_(sizes) {
tensor_impl_ = std::make_unique<TensorImpl>(
dtype_,
dim,
dtype,
sizes_.size(),
sizes_.data(),
data_ptr_,
dim_order_.data(),
strides_.data(),
data,
nullptr,
nullptr,
TensorShapeDynamism::DYNAMIC_BOUND);
}

Expand All @@ -63,12 +56,9 @@ class ManagedTensor {
}

private:
void* data_ptr_ = nullptr;
std::unique_ptr<TensorImpl> tensor_impl_;
std::vector<SizesType> sizes_;
std::vector<StridesType> strides_;
std::vector<DimOrderType> dim_order_;
ScalarType dtype_;
};

} // namespace executor
} // namespace torch
73 changes: 44 additions & 29 deletions extension/aten_util/make_aten_functor_from_et_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#endif
#include <ATen/native/Resize.h>
#include <executorch/extension/kernel_util/type_list.h>
#include <executorch/extension/runner_util/managed_tensor.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <torch/torch.h>

namespace torch {
Expand Down Expand Up @@ -107,25 +107,39 @@ struct type_convert<
typename remove_const_ref<ETensor>::type,
torch::executor::Tensor>>>
final {
public:
ATensor val;
std::unique_ptr<ManagedTensor> managed_tensor;
torch::executor::Tensor converted;
std::vector<exec_aten::SizesType> sizes;
explicit type_convert(ATensor value)
: val(value), converted(torch::executor::Tensor(nullptr)) {
for (auto size : val.sizes()) {
sizes.push_back(size);
}
torch::executor::ScalarType scalar_type =
static_cast<torch::executor::ScalarType>(val.scalar_type());
managed_tensor = std::make_unique<ManagedTensor>(
val.mutable_data_ptr(), sizes, scalar_type);
converted = managed_tensor->get_aliasing_tensor();
explicit type_convert(ATensor value) : value_(value) {
auto sizes = std::make_shared<std::vector<Tensor::SizesType>>(
value_.sizes().begin(), value_.sizes().end());
const ssize_t dim = sizes->size();
auto dim_order = std::make_shared<std::vector<Tensor::DimOrderType>>(dim);
auto strides = std::make_shared<std::vector<Tensor::StridesType>>(dim);

std::iota(dim_order->begin(), dim_order->end(), 0);
dim_order_to_stride_nocheck(
sizes->data(), dim_order->data(), dim, strides->data());

auto tensor_impl = std::make_shared<TensorImpl>(
static_cast<torch::executor::ScalarType>(value_.scalar_type()),
sizes->size(),
sizes->data(),
value_.mutable_data_ptr(),
dim_order->data(),
strides->data());

converted_ = std::unique_ptr<Tensor, std::function<void(Tensor*)>>(
new Tensor(tensor_impl.get()),
[sizes, dim_order, strides, tensor_impl](Tensor* pointer) {
delete pointer;
});
}

ETensor call() {
return converted;
return *converted_;
}

private:
ATensor value_;
std::unique_ptr<Tensor, std::function<void(Tensor*)>> converted_;
};

// Tensors: ETen to ATen.
Expand All @@ -139,21 +153,22 @@ struct type_convert<
typename remove_const_ref<ETensor>::type,
torch::executor::Tensor>>>
final {
public:
ETensor val;
at::Tensor converted;
std::vector<int64_t> sizes;
explicit type_convert(ETensor value) : val(value) {
for (auto size : val.sizes()) {
sizes.push_back(size);
}
c10::ScalarType scalar_type =
static_cast<c10::ScalarType>(val.scalar_type());
converted = at::from_blob(val.mutable_data_ptr(), sizes, scalar_type);
explicit type_convert(ETensor value)
: value_(value), sizes_(value_.sizes().begin(), value_.sizes().end()) {
converted_ = at::from_blob(
value_.mutable_data_ptr(),
sizes_,
static_cast<c10::ScalarType>(value_.scalar_type()));
}

ATensor call() {
return converted;
return converted_;
}

private:
ETensor value_;
at::Tensor converted_;
std::vector<int64_t> sizes_;
};

// Optionals: ATen to ETen.
Expand Down
1 change: 0 additions & 1 deletion extension/aten_util/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def define_common_targets():
],
exported_deps = [
"//executorch/extension/kernel_util:kernel_util",
"//executorch/extension/runner_util:managed_tensor",
"//executorch/runtime/core:core",
"//executorch/runtime/core:evalue",
"//executorch/runtime/core/exec_aten:lib",
Expand Down
31 changes: 9 additions & 22 deletions extension/runner_util/managed_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,29 @@ class ManagedTensor {
using DimOrderType = exec_aten::DimOrderType;
/// The type used for elements of `strides()`.
using StridesType = exec_aten::StridesType;

ManagedTensor() = delete;

explicit ManagedTensor(
void* data,
const std::vector<SizesType>& sizes,
ScalarType dtype)
: dtype_(dtype), sizes_(sizes), data_ptr_(data) {
: sizes_(sizes) {
#ifdef USE_ATEN_LIB
tensor_ = torch::from_blob(data, sizes, dtype_);
tensor_ = torch::from_blob(data, sizes, dtype);
#else
ssize_t dim = sizes.size();
dim_order_.resize(dim);
strides_.resize(dim);
for (size_t i = 0; i < dim; ++i) {
dim_order_[i] = i;
}
dim_order_to_stride_nocheck(
sizes.data(), dim_order_.data(), dim, strides_.data());
tensor_impl_ = std::make_unique<TensorImpl>(
dtype_,
dim,
dtype,
sizes_.size(),
sizes_.data(),
data_ptr_,
dim_order_.data(),
strides_.data(),
data,
nullptr,
nullptr,
TensorShapeDynamism::DYNAMIC_BOUND);
#endif
}

void resize(const std::vector<SizesType>& new_sizes) {
ET_CHECK_MSG(
new_sizes.size() == sizes_.size(),
"Cannot change rank of a managed tensor");
auto err = resize_tensor(
this->get_aliasing_tensor(),
exec_aten::ArrayRef<SizesType>(new_sizes.data(), new_sizes.size()));
Expand All @@ -88,15 +78,12 @@ class ManagedTensor {
}

private:
ScalarType dtype_;
std::unique_ptr<TensorImpl> tensor_impl_;
std::vector<SizesType> sizes_;
std::vector<StridesType> strides_;
std::vector<DimOrderType> dim_order_;
void* data_ptr_ = nullptr;
#ifdef USE_ATEN_LIB
Tensor tensor_;
#endif
};

} // namespace executor
} // namespace torch
27 changes: 0 additions & 27 deletions extension/runner_util/test/managed_tensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ TEST_F(ManagedTensorTest, Smoke) {

EXPECT_EQ(tensor.sizes(), ArrayRef<SizesType>(sizes_.data(), sizes_.size()));
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
std::vector<DimOrderType> expected_dim_order = {0, 1};
EXPECT_EQ(
tensor.dim_order(),
ArrayRef<DimOrderType>(
expected_dim_order.data(), expected_dim_order.size()));
std::vector<StridesType> expected_strides = {3, 1};
EXPECT_EQ(
tensor.strides(),
ArrayRef<StridesType>(expected_strides.data(), expected_strides.size()));
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
}

Expand All @@ -74,15 +65,6 @@ TEST_F(ManagedTensorTest, ResizeShrink) {
tensor.sizes(),
ArrayRef<SizesType>(expected_sizes.data(), expected_sizes.size()));
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
std::vector<DimOrderType> expected_dim_order = {0, 1};
EXPECT_EQ(
tensor.dim_order(),
ArrayRef<DimOrderType>(
expected_dim_order.data(), expected_dim_order.size()));
std::vector<StridesType> expected_strides = {2, 1};
EXPECT_EQ(
tensor.strides(),
ArrayRef<StridesType>(expected_strides.data(), expected_strides.size()));
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
}

Expand All @@ -95,14 +77,5 @@ TEST_F(ManagedTensorTest, Resize) {
tensor.sizes(),
ArrayRef<SizesType>(expected_sizes.data(), expected_sizes.size()));
EXPECT_EQ(tensor.scalar_type(), ScalarType::Long);
std::vector<DimOrderType> expected_dim_order = {0, 1};
EXPECT_EQ(
tensor.dim_order(),
ArrayRef<DimOrderType>(
expected_dim_order.data(), expected_dim_order.size()));
std::vector<StridesType> expected_strides = {2, 1};
EXPECT_EQ(
tensor.strides(),
ArrayRef<StridesType>(expected_strides.data(), expected_strides.size()));
EXPECT_EQ(tensor.const_data_ptr(), data_.data());
}