Add an option to getWriteableTensorData to avoid copying CUDA tensors to CPU #46524

Closed
wants to merge 5 commits
Changes from 1 commit
5 changes: 3 additions & 2 deletions torch/csrc/jit/serialization/pickler.cpp
@@ -596,12 +596,13 @@ void Pickler::pushTuple(const IValue& ivalue) {
 }
 }

-WriteableTensorData getWriteableTensorData(const at::Tensor& tensor) {
+WriteableTensorData getWriteableTensorData(
+    const at::Tensor& tensor, bool toCpu) {
   WriteableTensorData result;
   result.tensor_ = tensor;
   result.size_ = tensor.storage().nbytes();
   // TODO HIP support
-  if (tensor.storage().device_type() == DeviceType::CUDA) {
+  if (tensor.storage().device_type() == DeviceType::CUDA && toCpu) {
     // NB: This new tensor is created to support cuda tensors.
     // Storages can be mutated when converting tensors from cuda to cpu,
     // and we need a cpu tensor to copy data from.
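The tail of this hunk is collapsed in the view above. For orientation only, here is a rough sketch of how the whole function might read with the new flag; everything after the CUDA check is an assumption (the visible NB comment implies the real copy logic builds a fresh CPU tensor over the storage, so the tensor.to(at::kCPU) line below is a simplified stand-in, not the PR's code):

WriteableTensorData getWriteableTensorData(
    const at::Tensor& tensor, bool toCpu) {
  WriteableTensorData result;
  result.tensor_ = tensor;                  // keep a handle to the original tensor
  result.size_ = tensor.storage().nbytes(); // record size of the backing storage
  if (tensor.storage().device_type() == DeviceType::CUDA && toCpu) {
    // Simplified stand-in for the collapsed copy logic: materialize a host
    // copy so the pickler can read the bytes from CPU memory.
    result.tensor_ = tensor.to(at::kCPU);
  }
  // With toCpu == false a CUDA tensor is returned as-is: no device-to-host
  // copy is made, and the caller is responsible for reading device memory.
  return result;
}

Either way the flag only gates the copy; size_ is filled in from the storage before the branch, so the reported record size is unchanged.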
5 changes: 3 additions & 2 deletions torch/csrc/jit/serialization/pickler.h
@@ -107,7 +107,7 @@ struct WriteableTensorData {

  private:
   friend TORCH_API WriteableTensorData
-  getWriteableTensorData(const at::Tensor& tensor);
+  getWriteableTensorData(const at::Tensor& tensor, bool toCpu);
   at::Tensor tensor_;
   uint64_t size_;
 };
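The friend declaration has to change in lockstep with the free function: only the overload with the exact signature named here may touch the private tensor_ and size_ members, so adding the bool parameter in pickler.cpp without updating this line would break the build. A self-contained toy illustration of that pattern (the names below are invented for illustration, not PyTorch code):

#include <cstdint>

struct Payload {
 private:
  // Only the overload with this exact signature is granted access.
  friend Payload makePayload(int value, bool doubled);
  int value_{0};
  std::uint64_t size_{0};
};

Payload makePayload(int value, bool doubled) {
  Payload p;
  p.value_ = doubled ? value * 2 : value; // allowed: this overload is the friend
  p.size_ = sizeof(int);
  return p;
}

int main() {
  Payload p = makePayload(21, /*doubled=*/true);
  (void)p;
  return 0;
}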
@@ -266,7 +266,8 @@ class TORCH_API Pickler {

 // returns a (tensor, record_size) for a tensor, converting it to a CPU tensor
 // if necessary
-TORCH_API WriteableTensorData getWriteableTensorData(const at::Tensor& tensor);
+TORCH_API WriteableTensorData getWriteableTensorData(
+    const at::Tensor& tensor, bool toCpu=true);

 // return the value of the tensor's storage pointer
 uint64_t getStorageKey(const at::Tensor& tensor);
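A hypothetical call site, purely for illustration (the helper name and the writerHandlesCuda parameter are invented here, and the torch::jit namespace is assumed from pickler.h), showing what the new defaulted parameter changes for callers:

#include <torch/csrc/jit/serialization/pickler.h>

// Invented helper: decide whether the writer needs host memory or can
// consume CUDA memory directly (e.g. via its own cudaMemcpy).
void writeTensorRecord(const at::Tensor& t, bool writerHandlesCuda) {
  // Default (toCpu = true): same behavior as before this PR, a CUDA tensor
  // is copied to CPU first and the result is backed by host memory.
  auto hostBacked = torch::jit::getWriteableTensorData(t);

  // Opt out (toCpu = false): the returned WriteableTensorData keeps the
  // original CUDA storage, so no device-to-host copy happens here.
  auto asIs = torch::jit::getWriteableTensorData(t, /*toCpu=*/!writerHandlesCuda);

  (void)hostBacked;
  (void)asIs;
}

Because the parameter defaults to true, existing callers keep the old copy-to-CPU behavior without any source changes.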