Skip to content

Commit

Permalink
Use variable_data() in tensor_to_numpy (#22214)
Browse files Browse the repository at this point in the history
Summary:
As part of the Variable/Tensor merge, we want to gradually remove call sites of `tensor_data()` and the API itself, and instead uses `variable_data()`. This PR removes the `tensor_data()` call in the tensor_to_numpy conversion path.
Pull Request resolved: #22214

Differential Revision: D15997397

Pulled By: yf225

fbshipit-source-id: 6fcab7b14e138824fc2adb5434512bcf868ca375
  • Loading branch information
Will Feng authored and facebook-github-bot committed Jun 26, 2019
1 parent f176950 commit 5f84f37
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 1 addition & 6 deletions tools/autograd/templates/python_variable_methods.cpp
Expand Up @@ -428,12 +428,7 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
HANDLE_TH_ERRORS
jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
if (self_.requires_grad()) {
throw std::runtime_error(
"Can't call numpy() on Variable that requires grad. "
"Use var.detach().numpy() instead.");
}
return torch::utils::tensor_to_numpy(self_.tensor_data());
return torch::utils::tensor_to_numpy(self_);
END_HANDLE_TH_ERRORS
}

Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/utils/tensor_numpy.cpp
Expand Up @@ -85,6 +85,11 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
if (tensor.type().backend() != Backend::CPU) {
throw TypeError("NumPy conversion for %s is not supported", tensor.type().toString().c_str());
}
if (tensor.requires_grad()) {
throw std::runtime_error(
"Can't call numpy() on Variable that requires grad. "
"Use var.detach().numpy() instead.");
}
auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
auto sizes = to_numpy_shape(tensor.sizes());
auto strides = to_numpy_shape(tensor.strides());
Expand All @@ -110,7 +115,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
// object of the ndarray to the tensor and disabling resizes on the storage.
// This is not sufficient. For example, the tensor's storage may be changed
// via Tensor.set_, which can free the underlying memory.
PyObject* py_tensor = THPVariable_Wrap(make_variable(tensor, false));
TORCH_INTERNAL_ASSERT(tensor.is_variable());
PyObject* py_tensor = THPVariable_Wrap(tensor);
if (!py_tensor) throw python_error();
if (PyArray_SetBaseObject((PyArrayObject*)array.get(), py_tensor) == -1) {
return nullptr;
Expand Down

0 comments on commit 5f84f37

Please sign in to comment.