Skip to content
17 changes: 8 additions & 9 deletions aten/src/ATen/DLConvertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,20 @@ static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
return ctx;
}

static DeviceType getATenDeviceType(const DLContext& ctx) {
static Device getATenDevice(const DLContext& ctx) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return DeviceType::CPU;
return at::Device(DeviceType::CPU);
case DLDeviceType::kDLGPU:
return DeviceType::CUDA;
return at::Device(DeviceType::CUDA, ctx.device_id);
case DLDeviceType::kDLOpenCL:
return DeviceType::OPENCL;
return at::Device(DeviceType::OPENCL, ctx.device_id);
case DLDeviceType::kDLROCM:
return DeviceType::HIP;
return at::Device(DeviceType::HIP, ctx.device_id);
default:
throw std::logic_error(
"Unsupported device_type: " + std::to_string(ctx.device_type));
}
return DeviceType::CPU; // impossible
}

ScalarType toScalarType(const DLDataType& dtype) {
Expand Down Expand Up @@ -173,7 +172,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
}

Tensor fromDLPack(const DLManagedTensor* src) {
DeviceType device_type = getATenDeviceType(src->dl_tensor.ctx);
Device device = getATenDevice(src->dl_tensor.ctx);
ScalarType stype = toScalarType(src->dl_tensor.dtype);
auto deleter = [src](void* self) {
src->deleter(const_cast<DLManagedTensor*>(src));
Expand All @@ -182,14 +181,14 @@ Tensor fromDLPack(const DLManagedTensor* src) {
return at::from_blob(src->dl_tensor.data,
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
deleter,
at::device(device_type).dtype(stype));
at::device(device).dtype(stype));
}

return at::from_blob(
src->dl_tensor.data,
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
deleter,
at::device(device_type).dtype(stype));
at::device(device).dtype(stype));
}
} // namespace at
6 changes: 4 additions & 2 deletions aten/src/ATen/SparseTensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size,
indices_mult_cpu_vec[i] = mult;
mult *= full_size[i];
}
auto indices_mult_cpu = indices.dispatch_type().cpu()
.tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
auto indices_mult_cpu = at::from_blob(
indices_mult_cpu_vec.data(),
/*size=*/{sparse_dim, 1},
indices.options().device(kCPU));
// NB: must be blocking because this blob may be freed after this closure,
// and non_blocking copy will see garbage.
auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
Expand Down
25 changes: 25 additions & 0 deletions aten/src/ATen/TensorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,29 @@ bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) {
return contig_if_nonempty;
}

namespace detail {

std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
  // Contiguous (row-major) strides: the last dimension is densest (stride 1),
  // and each earlier dimension's stride is the product of all later sizes.
  const size_t ndim = sizes.size();
  std::vector<int64_t> strides(ndim);
  int64_t running = 1;
  for (size_t idx = ndim; idx-- > 0;) {
    strides[idx] = running;
    running *= sizes[idx];
  }
  return strides;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplicates this method in TensorImpl.h

  inline void update_to_contiguous_strides(size_t old_dim) {
    strides_.resize(sizes_.size(), 0);
    if (dim() > 0) {
      int last_idx = dim() - 1;
      strides_[last_idx] = 1;
      for (auto i = last_idx - 1; i >= 0; --i) {
        strides_[i] = strides_[i + 1] * std::max<int64_t>(sizes_[i + 1], 1);
      }
    }
    is_contiguous_ = true;
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the code existed from before. I'll accept the movement as long as we hide these methods in a private namespace (e.g., impl)


int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
  // The minimal storage numel is one past the largest reachable element
  // offset under `strides`. Any zero-sized dimension means the tensor has
  // no elements, so no storage is required at all.
  const size_t ndim = sizes.size();
  int64_t numel = 1;
  for (size_t d = 0; d != ndim; ++d) {
    const int64_t extent = sizes[d];
    if (extent == 0) {
      return 0;
    }
    numel += strides[d] * (extent - 1);
  }
  return numel;
}
} // namespace detail
} // namespace at
7 changes: 6 additions & 1 deletion aten/src/ATen/TensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,9 @@ CAFFE2_API void* maybe_data_ptr(const TensorArg& tensor);
// constructing a tensor, e.g., when you want to choose a kernel strategy based
// on whether a subgeometry is contiguous.
CAFFE2_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
}

namespace detail {
CAFFE2_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
CAFFE2_API int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides);
} // namespace detail
} // namespace at
3 changes: 0 additions & 3 deletions aten/src/ATen/UndefinedType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ Device UndefinedType::getDeviceFromPtr(void*) const {
AT_ERROR("getDeviceFromPtr not defined for UndefinedType");
}

Storage UndefinedType::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
AT_ERROR("storageFromBlob not defined for UndefinedType");
}
Storage UndefinedType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
AT_ERROR("unsafeStorageFromTH not defined for UndefinedType");
}
Expand Down
1 change: 0 additions & 1 deletion aten/src/ATen/UndefinedType.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct UndefinedType final : public TypeDefault {
virtual Backend backend() const override;
virtual Allocator* allocator() const override;
virtual Device getDeviceFromPtr(void* data) const override;
virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
virtual std::unique_ptr<Generator> generator() const override;
virtual const char * toString() const override;
Expand Down
3 changes: 0 additions & 3 deletions aten/src/ATen/core/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ struct CAFFE2_API Type {
bool is_undefined() const noexcept { return is_undefined_; }
virtual Allocator * allocator() const = 0;
virtual Device getDeviceFromPtr(void * data) const = 0;
virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const = 0;
virtual std::unique_ptr<Generator> generator() const = 0;
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
Expand Down Expand Up @@ -176,8 +175,6 @@ struct CAFFE2_API Type {
bool create_graph) const = 0;
virtual void set_data(Tensor & self, Tensor new_data) const = 0;

virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
virtual Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const = 0;
virtual Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const = 0;

Expand Down
13 changes: 1 addition & 12 deletions aten/src/ATen/native/Resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,12 @@ inline TensorImpl* resize_impl_cpu_(
return self;
}

static inline int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
int64_t storage_size = 1;
for (size_t dim = 0; dim < sizes.size(); ++dim) {
if (sizes[dim] == 0) {
return 0;
}
storage_size += strides[dim] * (sizes[dim] - 1);
}
return storage_size;
}

static inline void checkInBoundsForStorage(
IntArrayRef size,
IntArrayRef stride,
int64_t storage_offset,
const Storage& new_storage) {
int64_t storage_size = computeStorageSize(size, stride);
int64_t storage_size = detail::computeStorageSize(size, stride);
if (storage_size == 0) {
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
return;
Expand Down
9 changes: 6 additions & 3 deletions aten/src/ATen/native/cuda/TensorTransformations.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,16 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
return out_tensor;
}

auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});
auto flip_dims_t = at::from_blob(
flip_dims.data(), {static_cast<int64_t>(flip_dims.size())}, at::device(kCPU).dtype(kLong));

auto shape = in_tensor.sizes().vec();
auto shape_t = at::CPU(kLong).tensorFromBlob(shape.data(), {static_cast<int64_t>(shape.size())});
auto shape_t = at::from_blob(
shape.data(), {static_cast<int64_t>(shape.size())}, at::device(kCPU).dtype(kLong));

auto strides = in_tensor.strides().vec();
auto strides_t = at::CPU(kLong).tensorFromBlob(strides.data(), {static_cast<int64_t>(strides.size())});
auto strides_t = at::from_blob(
strides.data(), {static_cast<int64_t>(strides.size())}, at::device(kCPU).dtype(kLong));

// stride_contiguous is the stride of non-contiguous tensor after calling contiguous(),
// it is used to compute indices for each element in non-contiguous tensor
Expand Down
41 changes: 40 additions & 1 deletion aten/src/ATen/templates/Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,53 @@
#include <c10/core/TensorOptions.h>
#include <ATen/core/Reduction.h>
#include <c10/util/Optional.h>
#include <ATen/TensorUtils.h>

namespace at {

using native::from_blob;
using native::tensor;

${function_declarations}

inline Tensor from_blob(
    void* data,
    IntArrayRef sizes,
    IntArrayRef strides,
    const std::function<void(void*)>& deleter,
    const TensorOptions& options = {}) {
  // Wrap externally-owned memory in a Tensor without copying. The storage
  // holds a DataPtr whose context invokes `deleter` once the last reference
  // to the storage is dropped; no allocator is attached and the storage is
  // not resizable, since we do not own the underlying buffer.
  const int64_t storage_numel = detail::computeStorageSize(sizes, strides);
  auto data_ptr = InefficientStdFunctionContext::makeDataPtr(
      data, deleter, options.device());
  Storage storage(
      options.dtype(),
      storage_numel,
      std::move(data_ptr),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  // Build an empty tensor with the requested options, then point it at the
  // wrapped storage with the caller's geometry.
  return empty({0}, options).set_(storage, 0, sizes, strides);
}

inline Tensor from_blob(
    void* data,
    IntArrayRef sizes,
    const std::function<void(void*)>& deleter,
    const TensorOptions& options = {}) {
  // No strides given: assume a contiguous (row-major) layout over `sizes`.
  const auto contig_strides = detail::defaultStrides(sizes);
  return from_blob(data, sizes, contig_strides, deleter, options);
}

inline Tensor from_blob(
    void* data,
    IntArrayRef sizes,
    IntArrayRef strides,
    const TensorOptions& options = {}) {
  // No deleter supplied: the caller retains ownership of the buffer, so the
  // storage's cleanup hook is a no-op.
  auto noop_deleter = [](void*) {};
  return from_blob(data, sizes, strides, noop_deleter, options);
}

inline Tensor from_blob(
    void* data,
    IntArrayRef sizes,
    const TensorOptions& options = {}) {
  // Contiguous layout and caller-owned memory (no-op deleter).
  auto noop_deleter = [](void*) {};
  const auto contig_strides = detail::defaultStrides(sizes);
  return from_blob(data, sizes, contig_strides, noop_deleter, options);
}

namespace detail {

static inline TypeExtendedInterface & infer_type(const Tensor & t) {
Expand Down
17 changes: 0 additions & 17 deletions aten/src/ATen/templates/NativeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,6 @@ struct Type;
namespace at {
namespace native {

inline Tensor from_blob(
void* data,
IntArrayRef sizes,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {}) {
return at::getType(options).tensorFromBlob(data, sizes, deleter);
}

inline Tensor from_blob(
void* data,
IntArrayRef sizes,
IntArrayRef strides,
const std::function<void(void*)>& deleter,
const TensorOptions& options = {}) {
return at::getType(options).tensorFromBlob(data, sizes, strides, deleter);
}

// These functions are defined in native/TensorFactories.cpp.
#define TENSOR(T, S, _1) \
CAFFE2_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
Expand Down
3 changes: 0 additions & 3 deletions aten/src/ATen/templates/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ struct CAFFE2_API Type {
bool is_undefined() const noexcept { return is_undefined_; }
virtual Allocator * allocator() const = 0;
virtual Device getDeviceFromPtr(void * data) const = 0;
virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const = 0;
virtual std::unique_ptr<Generator> generator() const = 0;
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
Expand Down Expand Up @@ -119,8 +118,6 @@ struct CAFFE2_API Type {
bool create_graph) const = 0;
virtual void set_data(Tensor & self, Tensor new_data) const = 0;

virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
virtual Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const = 0;
virtual Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const = 0;

Expand Down
41 changes: 2 additions & 39 deletions aten/src/ATen/templates/TypeDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,51 +57,14 @@ Type & TypeDefault::toBackend(Backend b) const {
Type & TypeDefault::toScalarType(ScalarType s) const {
return at::globalContext().getNonVariableType(backend(),s);
}
static std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
std::vector<int64_t> strides(sizes.size());
int64_t stride = 1;
for(size_t i = sizes.size(); i > 0; --i) {
strides[i-1] = stride;
stride *= sizes[i-1];
}
return strides;
}
static int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
// size of the underlying storage is 1 bigger than the offset
// of the last element according to stride
int64_t size = 1;
for(size_t i = 0; i < sizes.size(); i++) {
if(sizes[i] == 0) {
return 0;
}
size += strides[i]*(sizes[i]-1);
}
return size;
}
Tensor TypeDefault::tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter) const {
return tensorFromBlob(data, sizes, defaultStrides(sizes), deleter);
}
Tensor TypeDefault::tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter) const {
auto storage = storageFromBlob(data, computeStorageSize(sizes, strides), deleter);
return at::empty({0}, options()).set_(storage, 0, sizes, strides);
}
Tensor TypeDefault::tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const {
return tensorWithAllocator(sizes, defaultStrides(sizes), std::move(allocator));
return tensorWithAllocator(sizes, detail::defaultStrides(sizes), std::move(allocator));
}
Tensor TypeDefault::tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const {
auto storage = storageWithAllocator(computeStorageSize(sizes, strides), std::move(allocator));
auto storage = storageWithAllocator(detail::computeStorageSize(sizes, strides), std::move(allocator));
return at::empty({0}, options()).set_(storage, 0, sizes, strides);
}

Storage TypeDefault::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
return Storage(
typeMeta(),
size,
InefficientStdFunctionContext::makeDataPtr(
data, deleter, getDeviceFromPtr(data)),
/*allocator=*/nullptr,
/*resizable=*/false);
}
Storage TypeDefault::storageWithAllocator(int64_t size, Allocator* allocator) const {
// Potentially the storage might be marked as resizable too here
return Storage(typeMeta(), size, allocator, /*resizable=*/false);
Expand Down
3 changes: 0 additions & 3 deletions aten/src/ATen/templates/TypeDefault.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
bool create_graph) const override;
void set_data(Tensor & self, Tensor new_data) const override;

Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const override;
Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const override;
Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const override;
Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const override;

Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
Expand Down
9 changes: 4 additions & 5 deletions aten/src/ATen/test/atest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TEST(atest, atest) {

float data[] = {1, 2, 3, 4, 5, 6};

auto f = CPU(kFloat).tensorFromBlob(data, {1, 2, 3});
auto f = from_blob(data, {1, 2, 3});
auto f_a = f.accessor<float, 3>();

ASSERT_EQ(f_a[0][0][0], 1.0);
Expand All @@ -72,7 +72,7 @@ TEST(atest, atest) {
int isgone = 0;
{
auto f2 =
CPU(kFloat).tensorFromBlob(data, {1, 2, 3}, [&](void*) { isgone++; });
from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
}
ASSERT_EQ(isgone, 1);
}
Expand All @@ -81,7 +81,7 @@ TEST(atest, atest) {
Tensor a_view;
{
auto f2 =
CPU(kFloat).tensorFromBlob(data, {1, 2, 3}, [&](void*) { isgone++; });
from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
a_view = f2.view({3, 2, 1});
}
ASSERT_EQ(isgone, 0);
Expand All @@ -93,8 +93,7 @@ TEST(atest, atest) {
int isgone = 0;
{
auto base = at::empty({1,2,3}, TensorOptions(kCUDA));
auto f2 = CUDA(kFloat).tensorFromBlob(
base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
auto f2 = from_blob(base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
}
ASSERT_EQ(isgone, 1);
}
Expand Down
Loading