Skip to content

Commit

Permalink
from review: use as_strided() to create a view
Browse files Browse the repository at this point in the history
  • Loading branch information
mattip committed Aug 23, 2022
1 parent 1a00f37 commit d117e02
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions aten/src/ATen/DLConvertor.cpp
Expand Up @@ -209,43 +209,41 @@ struct ATenDLMTensor {
};

void deleter(DLManagedTensor* arg) {
delete [] arg->dl_tensor.strides;
delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}

// This function returns a shared_ptr to memory managed DLpack tensor
// constructed out of ATen tensor
DLManagedTensor* toDLPack(const Tensor& src) {
// create a new tensor with possibly normalized strides
// gh-83069
auto shape = src.sizes();
auto strides = src.strides().vec();
for (int i=0; i<src.dim(); i++) {
if (shape[i] < 2) {
strides[i] = 1;
}
}

auto view = src.as_strided(shape, strides, src.storage_offset());
ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
atDLMTensor->handle = src;
atDLMTensor->handle = view;
atDLMTensor->tensor.manager_ctx = atDLMTensor;
atDLMTensor->tensor.deleter = &deleter;
atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
int64_t device_id = 0;
if (src.is_cuda()) {
device_id = src.get_device();
}
atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
atDLMTensor->tensor.dl_tensor.ndim = src.dim();
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
// Normalize the strides to 1 wherever shape < 2
// gh-83069
auto shape = src.sizes().data();
int64_t *strides = new int64_t[src.dim()];
for (int i=0; i<src.dim(); i++) {
if (shape[i] < 2) {
strides[i] = 1;
}
else {
strides[i] = src.strides()[i];
}
}
atDLMTensor->tensor.dl_tensor.shape =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<int64_t*>(shape);
const_cast<int64_t*>(view.sizes().data());
atDLMTensor->tensor.dl_tensor.strides =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<int64_t*>(strides);
const_cast<int64_t*>(view.strides().data());
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
return &(atDLMTensor->tensor);
}
Expand Down

0 comments on commit d117e02

Please sign in to comment.