diff --git a/test/test_torch.py b/test/test_torch.py
index 6eb773c54d13..3e51f7bdd2b4 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1447,6 +1447,18 @@ def test_mv(self):
         self.assertEqual(res1, res2)
 
+    def test_numpy_args(self):
+        x1 = torch.randn(10)
+        x2 = torch.randn(10)
+        res1 = torch.add(input=x1, other=x2)
+        res2 = torch.add(x1=x1, x2=x2)
+        self.assertEqual(res1, res2)
+
+        x1 = torch.randn(10, 10, 10)
+        res1 = x1.sum(dim=(0, 2), keepdim=True)
+        res2 = x1.sum(axis=(0, 2), keepdims=True)
+        self.assertEqual(res1, res2)
+
     def test_add(self):
         # [res] torch.add([res,] tensor1, tensor2)
         m1 = torch.randn(100, 100)
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 63e462542fe1..b85ee1a6c91d 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -32,6 +32,28 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"std::string", ParameterType::STRING},
 };
 
+// Default arg name translations for compatibility with NumPy.
+//
+// Example:
+// ```python
+// t = torch.randn(10,10)
+// torch.sum(a=t, axis=0, keepdim=True)
+// ```
+//
+// A vector is necessary, because we might need to try multiple values.
+// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input tensor.
+// Rather than annotate each function separately with whether it should take "x" or "a",
+// just try both.
+//
+// TODO: Allow individual functions to specify non-default translations:
+// For example, `torch.pow` should translate "exponent" to "x2".
+static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
+  {"dim", {"axis"}},
+  {"keepdim", {"keepdims"}},
+  {"input", {"x", "a", "x1"}},
+  {"other", {"x2"}},
+};
+
 // TODO: remove this. This is a temporary list of functions that allow Python
 // numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
 // overloads and binding to the Tensor overload with a number of a different
@@ -94,11 +116,13 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
   } else {
     name = name_str;
   }
-#if PY_MAJOR_VERSION == 2
-  python_name = PyString_InternFromString(name.c_str());
-#else
-  python_name = PyUnicode_InternFromString(name.c_str());
-#endif
+  python_name = THPUtils_internString(name);
+  auto np_compat_it = numpy_compatibility_arg_names.find(name);
+  if (np_compat_it != numpy_compatibility_arg_names.end()) {
+    for (const auto& str: np_compat_it->second) {
+      numpy_python_names.push_back(THPUtils_internString(str));
+    }
+  }
 }
 
 bool FunctionParameter::check(PyObject* obj) {
@@ -461,6 +485,12 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
       obj = PyTuple_GET_ITEM(args, arg_pos);
     } else if (kwargs) {
       obj = PyDict_GetItem(kwargs, param.python_name);
+      for (PyObject *numpy_name: param.numpy_python_names) {
+        if (obj) {
+          break;
+        }
+        obj = PyDict_GetItem(kwargs, numpy_name);
+      }
       is_kwd = true;
     }
 
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 8b0b1d3c74f7..a8ebb24024ae 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -177,6 +177,7 @@ struct FunctionParameter {
   // having this as a raw PyObject * will presumably leak it, but these are only held by static objects
   // anyway, and Py_Finalize can already be called when this is destructed.
   PyObject *python_name;
+  at::SmallVector<PyObject*, 5> numpy_python_names;
   at::Scalar default_scalar;
   std::vector<int64_t> default_intlist;
   union {
diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h
index 80dce9f7ee40..56f7b1cba3ba 100644
--- a/torch/csrc/utils/python_strings.h
+++ b/torch/csrc/utils/python_strings.h
@@ -57,3 +57,11 @@ inline PyObject* THPUtils_packString(const std::string& str) {
   return PyUnicode_FromStringAndSize(str.c_str(), str.size());
 #endif
 }
+
+inline PyObject* THPUtils_internString(const std::string& str) {
+#if PY_MAJOR_VERSION == 2
+  return PyString_InternFromString(str.c_str());
+#else
+  return PyUnicode_InternFromString(str.c_str());
+#endif
+}
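For readers skimming the patch, here is a minimal sketch of how the new aliases behave from the Python side once this change is applied. The shapes and values are illustrative only (not taken from the test suite), but the keyword names exercised are exactly the ones registered in `numpy_compatibility_arg_names`:

```python
import torch

t = torch.randn(4, 5)

# "dim"/"keepdim" (Torch spelling) and "axis"/"keepdims" (NumPy spelling)
# resolve to the same underlying argument, so both calls are equivalent.
assert torch.equal(t.sum(dim=0, keepdim=True),
                   t.sum(axis=0, keepdims=True))

# Binary ops: NumPy's "x1"/"x2" map onto Torch's "input"/"other".
a, b = torch.randn(3), torch.randn(3)
assert torch.equal(torch.add(input=a, other=b),
                   torch.add(x1=a, x2=b))
```

Note the lookup order in `FunctionSignature::parse`: the canonical Torch name is tried first, and the NumPy aliases are only consulted when that name is absent from `kwargs`. Because the translation happens once at parser construction time (in `FunctionParameter`), any native function declaring `dim`, `keepdim`, `input`, or `other` picks up the aliases automatically, with no per-function annotation.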