Skip to content

Commit

Permalink
Make NumPy dependency dynamic (pytorch#52794)
Browse files Browse the repository at this point in the history
Summary:
Move NumPy initialization from `initModule()` to a singleton inside the
`torch::utils::is_numpy_available()` function.
If initialization fails, this singleton prints a warning that NumPy
integration is not available, rather than failing to import torch altogether.
The warning will be printed only once, and will look something like the
following:
```
UserWarning: Failed to initialize NumPy: No module named 'numpy.core' (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:66.)
```

This is helpful if PyTorch was compiled with the wrong NumPy version, or if
NumPy is not commonly available on the platform (which is often the case
on AARCH64 or Apple M1).

Test that PyTorch is usable after numpy is uninstalled at the end of
`_test1` CI config.

Pull Request resolved: pytorch#52794

Reviewed By: seemethere

Differential Revision: D26650509

Pulled By: malfet

fbshipit-source-id: a2d98769ef873862c3704be4afda075d76d3ad06
  • Loading branch information
malfet authored and Sacha Refshauge committed Mar 31, 2021
1 parent 805b750 commit 79c5bfc
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 21 deletions.
1 change: 1 addition & 0 deletions .jenkins/pytorch/fake_numpy/numpy.py
@@ -0,0 +1 @@
raise ModuleNotFoundError("Sorry PyTorch, but our NumPy is in the other folder")
9 changes: 8 additions & 1 deletion .jenkins/pytorch/test.sh
Expand Up @@ -144,6 +144,12 @@ test_aten() {
fi
}

# Smoke test: torch must still import and work when NumPy cannot be loaded.
# The Python one-liner prepends the relative 'fake_numpy' directory to
# sys.path, so `import numpy` resolves to a stub module instead of the real
# package, then checks that calling Tensor.numpy() raises RuntimeError.
test_without_numpy() {
# Run from this script's own directory so the relative 'fake_numpy' path resolves.
pushd "$(dirname "${BASH_SOURCE[0]}")"
python -c "import sys;sys.path.insert(0, 'fake_numpy');from unittest import TestCase;import torch;x=torch.randn(3,3);TestCase().assertRaises(RuntimeError, lambda: x.numpy())"
# Restore the caller's working directory.
popd
}

# pytorch extensions require including torch/extension.h which includes all.h
# which includes utils.h which includes Parallel.h.
# So you can call for instance parallel_for() from your extension,
Expand Down Expand Up @@ -386,12 +392,13 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; t
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1 ]]; then
test_torch_deploy
fi
test_without_numpy
install_torchvision
test_python_shard1
test_aten
elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then
install_torchvision
test_python_shard2
test_aten
test_libtorch
test_custom_script_ops
test_custom_backend
Expand Down
3 changes: 0 additions & 3 deletions setup.py
Expand Up @@ -694,9 +694,6 @@ def configure_extension_build():
library_dirs.append(
os.path.dirname(cmake_cache_vars['CUDA_CUDA_LIB']))

if cmake_cache_vars['USE_NUMPY']:
extra_install_requires += ['numpy']

if build_type.is_debug():
if IS_WINDOWS:
extra_compile_args.append('/Z7')
Expand Down
6 changes: 0 additions & 6 deletions torch/csrc/Module.cpp
Expand Up @@ -71,9 +71,6 @@
#include <callgrind.h>
#endif

#define WITH_NUMPY_IMPORT_ARRAY
#include <torch/csrc/utils/numpy_stub.h>

namespace py = pybind11;

PyObject* module;
Expand Down Expand Up @@ -1004,9 +1001,6 @@ Call this whenever a new thread is created in order to propagate values from
ASSERT_TRUE(set_module_attr("DisableTorchFunction", (PyObject*)THPModule_DisableTorchFunctionType(), /* incref= */ false));
torch::set_disabled_torch_function_impl(PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr);
#ifdef USE_NUMPY
if (_import_array() < 0) return nullptr;
#endif
return module;
END_HANDLE_TH_ERRORS
}
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/utils/python_arg_parser.h
Expand Up @@ -56,7 +56,6 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/python_dimname.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/numpy_stub.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
Expand Down
16 changes: 9 additions & 7 deletions torch/csrc/utils/tensor_new.cpp
Expand Up @@ -152,12 +152,14 @@ std::vector<int64_t> compute_sizes(PyObject* seq) {

ScalarType infer_scalar_type(PyObject *obj) {
#ifdef USE_NUMPY
if (PyArray_Check(obj)) {
return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
}
if (PyArray_CheckScalar(obj)) {
THPObjectPtr arr(PyArray_FromScalar(obj, nullptr));
return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*) arr.get()));
if (is_numpy_available()) {
if (PyArray_Check(obj)) {
return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
}
if (PyArray_CheckScalar(obj)) {
THPObjectPtr arr(PyArray_FromScalar(obj, nullptr));
return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*) arr.get()));
}
}
#endif
if (PyFloat_Check(obj)) {
Expand Down Expand Up @@ -273,7 +275,7 @@ Tensor internal_new_from_data(
return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
}

if (PyArray_Check(data)) {
if (is_numpy_available() && PyArray_Check(data)) {
TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from numpy");
auto tensor = tensor_from_numpy(data, /*warn_if_not_writeable=*/!copy_numpy);
const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type;
Expand Down
45 changes: 42 additions & 3 deletions torch/csrc/utils/tensor_numpy.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/tensor_numpy.h>
#define WITH_NUMPY_IMPORT_ARRAY
#include <torch/csrc/utils/numpy_stub.h>

#ifndef USE_NUMPY
Expand All @@ -10,6 +11,11 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
// Stub for builds without USE_NUMPY: conversion from a NumPy array is
// impossible, so every call fails with std::runtime_error.
at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) {
  const char* const reason = "PyTorch was compiled without NumPy support";
  throw std::runtime_error(reason);
}

// Stub for builds without USE_NUMPY: there is no NumPy to probe, so the
// query never returns a value and always throws std::runtime_error.
bool is_numpy_available() {
  const char* const reason = "PyTorch was compiled without NumPy support";
  throw std::runtime_error(reason);
}

// Stub for builds without USE_NUMPY: NumPy scalar types do not exist here,
// so the check cannot be answered and always throws std::runtime_error.
bool is_numpy_int(PyObject* obj) {
  const char* const reason = "PyTorch was compiled without NumPy support";
  throw std::runtime_error(reason);
}
Expand Down Expand Up @@ -38,6 +44,30 @@ using namespace torch::autograd;

namespace torch { namespace utils {

// Reports whether NumPy's C API could be initialized.
//
// The import is attempted by the initializer of a function-local static, so
// the result is computed on the first call and reused afterwards. When
// `_import_array()` fails, the pending Python exception is consumed, its
// message (if it can be rendered as UTF-8) is appended to a TORCH_WARN, and
// `false` is returned instead of propagating the error to the caller.
bool is_numpy_available() {
  static bool available = []() {
    if (_import_array() >= 0) {
      return true;
    }
    // Try to get exception message, print warning and return false
    std::string message = "Failed to initialize NumPy";
    PyObject* type = nullptr;
    PyObject* value = nullptr;
    PyObject* traceback = nullptr;
    // PyErr_Fetch clears the error indicator and transfers ownership of the
    // three objects to us; each of them may be null.
    PyErr_Fetch(&type, &value, &traceback);
    if (auto str = value ? PyObject_Str(value) : nullptr) {
      if (auto enc_str = PyUnicode_AsEncodedString(str, "utf-8", "strict")) {
        if (auto byte_str = PyBytes_AS_STRING(enc_str)) {
          message += ": " + std::string(byte_str);
        }
        Py_XDECREF(enc_str);
      }
      Py_XDECREF(str);
    }
    // Release the references handed over by PyErr_Fetch; without this the
    // exception type, value and traceback are leaked.
    Py_XDECREF(type);
    Py_XDECREF(value);
    Py_XDECREF(traceback);
    // PyObject_Str / PyUnicode_AsEncodedString may have raised on their own;
    // make sure no error is left pending when we return.
    PyErr_Clear();
    TORCH_WARN(message);
    return false;
  }();
  return available;
}
static std::vector<npy_intp> to_numpy_shape(IntArrayRef x) {
// shape and stride conversion from int64_t to npy_intp
auto nelem = x.size();
Expand Down Expand Up @@ -74,6 +104,9 @@ static std::vector<int64_t> seq_to_aten_shape(PyObject *py_seq) {
}

PyObject* tensor_to_numpy(const at::Tensor& tensor) {
if (!is_numpy_available()) {
throw std::runtime_error("Numpy is not available");
}
if (tensor.device().type() != DeviceType::CPU) {
throw TypeError(
"can't convert %s device type tensor to numpy. Use Tensor.cpu() to "
Expand Down Expand Up @@ -126,6 +159,9 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
}

at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) {
if (!is_numpy_available()) {
throw std::runtime_error("Numpy is not available");
}
if (!PyArray_Check(obj)) {
throw TypeError("expected np.ndarray (got %s)", Py_TYPE(obj)->tp_name);
}
Expand Down Expand Up @@ -245,15 +281,18 @@ ScalarType numpy_dtype_to_aten(int dtype) {
}

bool is_numpy_int(PyObject* obj) {
return PyArray_IsScalar((obj), Integer);
return is_numpy_available() && PyArray_IsScalar((obj), Integer);
}

bool is_numpy_scalar(PyObject* obj) {
return is_numpy_int(obj) || PyArray_IsScalar(obj, Bool) ||
PyArray_IsScalar(obj, Floating) || PyArray_IsScalar(obj, ComplexFloating);
return is_numpy_available() && (is_numpy_int(obj) || PyArray_IsScalar(obj, Bool) ||
PyArray_IsScalar(obj, Floating) || PyArray_IsScalar(obj, ComplexFloating));
}

at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
if (!is_numpy_available()) {
throw std::runtime_error("Numpy is not available");
}
auto cuda_dict = THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__"));
TORCH_INTERNAL_ASSERT(cuda_dict);

Expand Down
1 change: 1 addition & 0 deletions torch/csrc/utils/tensor_numpy.h
Expand Up @@ -11,6 +11,7 @@ at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable=true);
int aten_to_numpy_dtype(const at::ScalarType scalar_type);
at::ScalarType numpy_dtype_to_aten(int dtype);

bool is_numpy_available();
bool is_numpy_int(PyObject* obj);
bool is_numpy_scalar(PyObject* obj);

Expand Down

0 comments on commit 79c5bfc

Please sign in to comment.