Change StorageImpl to track byte count rather than element count (#37776)

Summary:
Pull Request resolved: #37776

* Remove type-specific size tracking in favor of byte size tracking in Storage and StorageImpl
* Rename numel() and set_numel() to nbytes() and set_nbytes()
* Add an enum tag argument to the Storage/StorageImpl constructors to indicate the new meaning of the size parameter (a usage sketch follows the commit metadata below)
* Update all callers of the changed API

Part of issue #33950
Pull Request resolved: #37028

Differential Revision: D21171334

Pulled By: ezyang

fbshipit-source-id: 37329a379de9a3a83cc5e9007e455a3e1c2d10b8
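
For orientation, a minimal before/after sketch of the call-site migration this commit performs. It mirrors the empty_cpu diff below; the dtype, nelements, and allocator setup is assumed from the surrounding ATen code, not spelled out in the commit message:

// Before: the size argument was an element count.
// auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
//     dtype,
//     nelements,
//     allocator->allocate(nelements * dtype.itemsize()),
//     allocator,
//     /*resizable=*/true);

// After: the size argument is a byte count, and the tag
// c10::StorageImpl::use_byte_size_t() makes the new meaning explicit
// at every call site.
int64_t size_bytes = nelements * dtype.itemsize();
auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
    c10::StorageImpl::use_byte_size_t(),
    dtype,
    size_bytes,
    allocator->allocate(size_bytes),
    allocator,
    /*resizable=*/true);

Accessors change accordingly: storage.nbytes() replaces storage.numel(), and callers that still need an element count divide by the itemsize, as the set_ overload in TensorShape.cpp does below.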
kurtamohler authored and facebook-github-bot committed May 5, 2020
1 parent 25ba802 commit 3706803
Showing 49 changed files with 782 additions and 589 deletions.
9 changes: 6 additions & 3 deletions aten/src/ATen/TensorUtils.cpp
@@ -298,17 +298,20 @@ std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
   return strides;
 }
 
-int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
+size_t computeStorageNbytes(
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    size_t itemsize_bytes) {
   // size of the underlying storage is 1 bigger than the offset
   // of the last element according to stride
-  int64_t size = 1;
+  size_t size = 1;
   for(size_t i = 0; i < sizes.size(); i++) {
     if(sizes[i] == 0) {
       return 0;
     }
     size += strides[i]*(sizes[i]-1);
   }
-  return size;
+  return size * itemsize_bytes;
 }
 
 // On a high level,
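A quick worked example of the new helper (illustrative values, not part of the diff): for a 2x3 tensor with contiguous strides {3, 1}, the farthest element sits at offset 3*(2-1) + 1*(3-1) = 5, so the storage needs 6 elements; at 4 bytes per float that is 24 bytes.

// Hypothetical call, using the at::detail declaration from TensorUtils.h below:
size_t nbytes = at::detail::computeStorageNbytes(
    /*sizes=*/{2, 3}, /*strides=*/{3, 1}, /*itemsize=*/sizeof(float));
// nbytes == (1 + 3*(2-1) + 1*(3-1)) * 4 == 24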
3 changes: 2 additions & 1 deletion aten/src/ATen/TensorUtils.h
@@ -144,7 +144,8 @@ CAFFE2_API void check_dim_size(
 
 namespace detail {
 CAFFE2_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
-CAFFE2_API int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides);
+CAFFE2_API size_t
+computeStorageNbytes(IntArrayRef sizes, IntArrayRef strides, size_t itemsize);
 CAFFE2_API c10::optional<std::vector<int64_t>> computeStride(
     IntArrayRef oldshape,
     IntArrayRef oldstride,
12 changes: 7 additions & 5 deletions aten/src/ATen/core/boxing/impl/test_helpers.h
@@ -17,12 +17,14 @@ inline at::Tensor dummyTensor(c10::DispatchKeySet ks) {
   auto* allocator = c10::GetCPUAllocator();
   int64_t nelements = 1;
   auto dtype = caffe2::TypeMeta::Make<float>();
+  int64_t size_bytes = nelements * dtype.itemsize();
   auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
-    dtype,
-    nelements,
-    allocator->allocate(nelements * dtype.itemsize()),
-    allocator,
-    /*resizable=*/true);
+    c10::StorageImpl::use_byte_size_t(),
+    dtype,
+    size_bytes,
+    allocator->allocate(size_bytes),
+    allocator,
+    /*resizable=*/true);
   return at::detail::make_tensor<c10::TensorImpl>(storage_impl, ks);
 }

8 changes: 4 additions & 4 deletions aten/src/ATen/function_wrapper.py
@@ -337,16 +337,16 @@ def __init__(self, reason):
 
 ALLOC_NOARGS_WRAP = {
     'THTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
-                 '(c10::Storage(scalarTypeToTypeMeta(${ScalarName}), 0, allocator(), true),'
+                 '(c10::Storage(c10::Storage::use_byte_size_t(), scalarTypeToTypeMeta(${ScalarName}), 0, allocator(), true),'
                  'DispatchKey::${Backend}).release()',
     'THByteTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
-                     '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Byte), 0, allocator(), true),'
+                     '(c10::Storage(c10::Storage::use_byte_size_t(), scalarTypeToTypeMeta(ScalarType::Byte), 0, allocator(), true),'
                      'DispatchKey::${Backend}).release()',
     'THBoolTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
-                     '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Bool), 0, allocator(), true),'
+                     '(c10::Storage(c10::Storage::use_byte_size_t(), scalarTypeToTypeMeta(ScalarType::Bool), 0, allocator(), true),'
                      'DispatchKey::${Backend}).release()',
     'THIndexTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
-                      '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Long), 0, allocator(), true),'
+                      '(c10::Storage(c10::Storage::use_byte_size_t(), scalarTypeToTypeMeta(ScalarType::Long), 0, allocator(), true),'
                       'DispatchKey::${Backend}).release()',
 }

7 changes: 4 additions & 3 deletions aten/src/ATen/native/Memory.cpp
@@ -22,11 +22,12 @@ Tensor pin_memory(const Tensor& self) {
   }
   auto* allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
   auto storage = Storage(
+      Storage::use_byte_size_t(),
       self.dtype(),
-      detail::computeStorageSize(self.sizes(), self.strides()),
+      detail::computeStorageNbytes(
+          self.sizes(), self.strides(), self.dtype().itemsize()),
       allocator,
-      /*resizable=*/false
-  );
+      /*resizable=*/false);
   auto tensor = at::empty({0}, self.options()).set_(storage, 0, self.sizes(), self.strides());
   tensor.copy_(self);
   return tensor;
39 changes: 26 additions & 13 deletions aten/src/ATen/native/Resize.h
@@ -19,10 +19,10 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size)
   if (!THTensor_getStoragePtr(self)) {
     THTensor_stealAndSetStoragePtr(self, THStorage_new(self->dtype()));
   }
-  if (new_size + self->storage_offset() > self->storage().numel()) {
-    THStorage_resize(
-        THTensor_getStoragePtr(self),
-        new_size + self->storage_offset());
+  int64_t new_size_bytes =
+      (new_size + self->storage_offset()) * self->dtype().itemsize();
+  if (new_size_bytes > self->storage().nbytes()) {
+    THStorage_resizeBytes(THTensor_getStoragePtr(self), new_size_bytes);
   }
 }
@@ -61,19 +61,31 @@ static inline void checkInBoundsForStorage(
     IntArrayRef size,
     IntArrayRef stride,
     int64_t storage_offset,
+    const caffe2::TypeMeta& data_type,
     const Storage& new_storage) {
-  int64_t storage_size = detail::computeStorageSize(size, stride);
-  if (storage_size == 0) {
+  int64_t storage_size_bytes =
+      detail::computeStorageNbytes(size, stride, data_type.itemsize());
+  int64_t storage_offset_bytes = storage_offset * data_type.itemsize();
+  if (storage_size_bytes == 0) {
     // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
     return;
   }
-  int64_t new_storage_size = new_storage.numel();
+  int64_t new_storage_size_bytes = new_storage.nbytes();
   TORCH_CHECK(
-      storage_offset + storage_size <= new_storage_size,
-      "setStorage: sizes ", size, ", strides ", stride, ","
-      " and storage offset ", storage_offset,
-      " requiring a storage size of ", storage_size + storage_offset,
-      " are out of bounds for storage with numel ", new_storage_size);
+      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
+      "setStorage: sizes ",
+      size,
+      ", strides ",
+      stride,
+      ","
+      " storage offset ",
+      storage_offset,
+      ", and itemsize ",
+      data_type.itemsize(),
+      " requiring a storage size of ",
+      storage_size_bytes,
+      " are out of bounds for storage of size ",
+      new_storage_size_bytes);
 }
 
 static inline void checkSetStorage(Tensor& result, Storage storage, int64_t storage_offset,
@@ -124,7 +136,8 @@ inline void setStrided(
     IntArrayRef stride,
     int64_t storage_offset) {
   auto* self_ = self.unsafeGetTensorImpl();
-  checkInBoundsForStorage(size, stride, storage_offset, self_->storage());
+  checkInBoundsForStorage(
+      size, stride, storage_offset, self_->dtype(), self_->storage());
 
   /* storage offset */
   TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
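To make the byte-level bounds check above concrete, a small arithmetic sketch with invented values:

// Hypothetical view: sizes {2, 3}, strides {3, 1}, storage_offset 2, float dtype.
//   storage_size_bytes   = computeStorageNbytes({2, 3}, {3, 1}, 4) = 24
//   storage_offset_bytes = 2 * 4                                   = 8
// checkInBoundsForStorage passes only when
//   storage_size_bytes + storage_offset_bytes <= new_storage.nbytes(),
// i.e. the backing storage must hold at least 32 bytes here.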
30 changes: 17 additions & 13 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -119,12 +119,14 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional<
 
   int64_t nelements = prod_intlist(size);
   auto dtype = options.dtype();
+  int64_t size_bytes = nelements * dtype.itemsize();
   auto storage_impl = c10::make_intrusive<StorageImpl>(
-    dtype,
-    nelements,
-    allocator->allocate(nelements * dtype.itemsize()),
-    allocator,
-    /*resizeable=*/true);
+    c10::StorageImpl::use_byte_size_t(),
+    dtype,
+    size_bytes,
+    allocator->allocate(size_bytes),
+    allocator,
+    /*resizeable=*/true);
 
   auto tensor = detail::make_tensor<TensorImpl>(std::move(storage_impl), at::DispatchKey::CPU);
   // Default TensorImpl has size [0]
@@ -976,18 +978,20 @@ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
 
 Tensor from_file(std::string filename, c10::optional<bool> shared, c10::optional<int64_t> size, const TensorOptions& options) {
   TORCH_CHECK(!options.pinned_memory(), "tensors constructed from a file cannot be pinned");
-  size_t my_size = size.value_or(0);
+  int64_t my_size = size.value_or(0);
   int flags = shared.value_or(false) ? TH_ALLOCATOR_MAPPED_SHARED : 0;
   auto dtype = options.dtype();
+  size_t size_bytes = my_size * dtype.itemsize();
   auto storage_impl = c10::make_intrusive<at::StorageImpl>(
-      dtype,
-      my_size,
-      THMapAllocator::makeDataPtr(
-          filename.c_str(), flags, my_size * dtype.itemsize(), nullptr),
-      /*allocator=*/nullptr,
-      /*resizable=*/false);
+      c10::StorageImpl::use_byte_size_t(),
+      dtype,
+      size_bytes,
+      THMapAllocator::makeDataPtr(
+          filename.c_str(), flags, size_bytes, nullptr),
+      /*allocator=*/nullptr,
+      /*resizable=*/false);
   auto tensor = detail::make_tensor<at::TensorImpl>(storage_impl, at::DispatchKey::CPU);
-  tensor.unsafeGetTensorImpl()->set_sizes_contiguous({storage_impl->numel()});
+  tensor.unsafeGetTensorImpl()->set_sizes_contiguous({my_size});
   return tensor;
 }

11 changes: 9 additions & 2 deletions aten/src/ATen/native/TensorShape.cpp
@@ -39,7 +39,9 @@ Tensor _shape_as_tensor(const Tensor& self) {
 }
 
 Tensor& set_(Tensor& result, Storage source) {
-  return result.set_(source, 0, static_cast<int64_t>(source.size()), {});
+  int64_t new_size =
+      static_cast<int64_t>(source.nbytes() / result.dtype().itemsize());
+  return result.set_(source, 0, new_size, {});
 }
 
 // unify with cuda implementation? This is not done to avoid a dispatch in resize_impl_cpu_
@@ -64,7 +66,12 @@ Tensor& set_tensor_(Tensor& result, const Tensor& source) {
 // way of getting the allocator to use for a device (c10::GetAllocator is not
 // the same as at::cuda::getCUDADeviceAllocator().
 Tensor& set_cpu_(Tensor& result) {
-  Storage storage(result.dtype(), 0, c10::GetAllocator(kCPU), true);
+  Storage storage(
+      Storage::use_byte_size_t(),
+      result.dtype(),
+      0,
+      c10::GetAllocator(kCPU),
+      true);
   return result.set_(storage, 0, {0}, {});
 }

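One subtlety in the new set_ overload above: since Storage now reports bytes, the element count is recovered by integer division, which truncates when nbytes is not a multiple of the itemsize. A sketch with invented numbers:

// A 10-byte storage viewed as float (itemsize 4):
//   new_size = 10 / 4 = 2 elements; the trailing 2 bytes are ignored.
int64_t new_size =
    static_cast<int64_t>(source.nbytes() / result.dtype().itemsize());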
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/MiscUtils.h
@@ -80,11 +80,11 @@ static inline Storage pin_memory(int64_t size) {
   auto* allocator = cuda::getPinnedMemoryAllocator();
   int64_t adjusted_size = size * sizeof(T);
   return Storage(
+      Storage::use_byte_size_t(),
       caffe2::TypeMeta::Make<uint8_t>(),
       adjusted_size,
       allocator,
-      /*resizable=*/false
-  );
+      /*resizable=*/false);
 }
 
 } // namespace native
8 changes: 5 additions & 3 deletions aten/src/ATen/native/cuda/Resize.cuh
@@ -21,11 +21,13 @@ static inline void maybe_resize_storage_cuda(TensorImpl* self, int64_t new_size)
   if (!THTensor_getStoragePtr(self)) {
     AT_ERROR("Tensor: invalid null storage");
   }
-  if (new_size + self->storage_offset() > self->storage().numel()) {
-    THCStorage_resize(
+  uint64_t new_size_bytes = (new_size + self->storage_offset()) * self->dtype().itemsize();
+  if (new_size_bytes > self->storage().nbytes()) {
+    THCStorage_resizeBytes(
         globalContext().getTHCState(),
         THTensor_getStoragePtr(self),
-        new_size + self->storage_offset());
+        new_size_bytes
+    );
   }
 }
 }
12 changes: 7 additions & 5 deletions aten/src/ATen/native/cuda/TensorFactories.cu
@@ -52,12 +52,14 @@ Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional<
   auto* allocator = at::cuda::getCUDADeviceAllocator();
   int64_t nelements = prod_intlist(size);
   auto dtype = options.dtype();
+  int64_t size_bytes = nelements * dtype.itemsize();
   auto storage_impl = c10::make_intrusive<StorageImpl>(
-    dtype,
-    nelements,
-    allocator->allocate(nelements * dtype.itemsize()),
-    allocator,
-    /*resizeable=*/true);
+    c10::StorageImpl::use_byte_size_t(),
+    dtype,
+    size_bytes,
+    allocator->allocate(size_bytes),
+    allocator,
+    /*resizeable=*/true);
 
   auto tensor = detail::make_tensor<TensorImpl>(storage_impl, DispatchKey::CUDA);
   // Default TensorImpl has size [0]
7 changes: 6 additions & 1 deletion aten/src/ATen/native/cuda/TensorShapeCUDA.cpp
@@ -11,7 +11,12 @@ namespace native {
 // way of getting the allocator to use for a device (c10::GetAllocator is not
 // the same as at::cuda::getCUDADeviceAllocator().
 Tensor& set_cuda_(Tensor& result) {
-  Storage storage(result.dtype(), 0, at::cuda::getCUDADeviceAllocator(), true);
+  Storage storage(
+      Storage::use_byte_size_t(),
+      result.dtype(),
+      0,
+      at::cuda::getCUDADeviceAllocator(),
+      true);
   return result.set_(storage, 0, {0}, {});
 }

6 changes: 4 additions & 2 deletions aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -109,10 +109,12 @@ Tensor MakeStridedQTensorCPU(
   TORCH_CHECK(
       isQIntType(typeMetaToScalarType(dtype)),
       "ScalarType is not supported in new_qtensor_cpu.");
+  int64_t size_bytes = nelements * dtype.itemsize();
   auto storage = c10::make_intrusive<StorageImpl>(
+      StorageImpl::use_byte_size_t(),
       dtype,
-      nelements,
-      allocator->allocate(nelements * dtype.itemsize()),
+      size_bytes,
+      allocator->allocate(size_bytes),
       allocator,
       /* resizable = */ true);
   auto tensor = detail::make_tensor<QTensorImpl>(
21 changes: 11 additions & 10 deletions aten/src/ATen/native/xnnpack/Factory.cpp
@@ -16,17 +16,18 @@ Tensor empty_with_tail_padding(
     const DimnameList maybe_names) {
   auto* const allocator_ptr = c10::GetDefaultMobileCPUAllocator();
   const int64_t nelements = prod_intlist(size);
+  size_t size_bytes = nelements * dtype.itemsize();
 
-  Tensor tensor(
-      c10::make_intrusive<c10::TensorImpl>(
-          c10::Storage{
-              dtype,
-              nelements,
-              allocator_ptr->allocate(nelements * dtype.itemsize()),
-              allocator_ptr,
-              /*resizable=*/true,
-          },
-          DispatchKeySet{DispatchKey::CPU}));
+  Tensor tensor(c10::make_intrusive<c10::TensorImpl>(
+      c10::Storage{
+          c10::Storage::use_byte_size_t(),
+          dtype,
+          size_bytes,
+          allocator_ptr->allocate(size_bytes),
+          allocator_ptr,
+          /*resizable=*/true,
+      },
+      DispatchKeySet{DispatchKey::CPU}));
 
   return namedinference::propagate_names_if_nonempty(
       tensor.resize_(size, memory_format),
6 changes: 4 additions & 2 deletions aten/src/ATen/quantized/Quantizer.cpp
@@ -81,10 +81,12 @@ inline Tensor new_qtensor(
   TORCH_CHECK(
       isQIntType(typeMetaToScalarType(dtype)),
       "ScalarType is not supported in new_qtensor.");
+  int64_t size_bytes = nelements * dtype.itemsize();
   auto storage = c10::make_intrusive<StorageImpl>(
+      StorageImpl::use_byte_size_t(),
       dtype,
-      nelements,
-      allocator->allocate(nelements * dtype.itemsize()),
+      size_bytes,
+      allocator->allocate(size_bytes),
       allocator,
       /*resizable=*/true);
   auto tensor = detail::make_tensor<QTensorImpl>(
9 changes: 5 additions & 4 deletions aten/src/ATen/templates/Functions.h
@@ -36,10 +36,10 @@ inline Tensor from_blob(
         " does not match device of data ", device);
   }
   auto storage = Storage(
+      Storage::use_byte_size_t(),
       options.dtype(),
-      detail::computeStorageSize(sizes, strides),
-      InefficientStdFunctionContext::makeDataPtr(
-          data, deleter, device),
+      detail::computeStorageNbytes(sizes, strides, options.dtype().itemsize()),
+      InefficientStdFunctionContext::makeDataPtr(data, deleter, device),
       /*allocator=*/nullptr,
       /*resizable=*/false);
   return empty({0}, options).set_(storage, 0, sizes, strides);
@@ -67,8 +67,9 @@ inline Tensor from_blob(
         " does not match device of data ", device);
   }
   auto storage = Storage(
+      Storage::use_byte_size_t(),
       options.dtype(),
-      detail::computeStorageSize(sizes, strides),
+      detail::computeStorageNbytes(sizes, strides, options.dtype().itemsize()),
       DataPtr(data, nullptr, [](void*) {}, device),
       /*allocator=*/nullptr,
       /*resizable=*/false);
7 changes: 6 additions & 1 deletion aten/src/ATen/test/extension_backend_test.cpp
@@ -14,7 +14,12 @@ Tensor empty_override(IntArrayRef size, const TensorOptions & options, c10::opti
   test_int = 1;
   auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
       Storage(
-          caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
+          Storage::use_byte_size_t(),
+          caffe2::TypeMeta::Make<float>(),
+          0,
+          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)),
+          nullptr,
+          false),
       DispatchKey::MSNPU);
   return Tensor(std::move(tensor_impl));
 }
