Provide a few default args for numpy translation #20451

Closed
12 changes: 12 additions & 0 deletions test/test_torch.py
@@ -1447,6 +1447,18 @@ def test_mv(self):

         self.assertEqual(res1, res2)

+    def test_numpy_args(self):
+        x1 = torch.randn(10)
+        x2 = torch.randn(10)
+        res1 = torch.add(input=x1, other=x2)
+        res2 = torch.add(x1=x1, x2=x2)
+        self.assertEqual(res1, res2)
+
+        x1 = torch.randn(10, 10, 10)
+        res1 = x1.sum(dim=(0, 2), keepdim=True)
+        res2 = x1.sum(axis=(0, 2), keepdims=True)
+        self.assertEqual(res1, res2)
+
     def test_add(self):
         # [res] torch.add([res,] tensor1, tensor2)
         m1 = torch.randn(100, 100)
40 changes: 35 additions & 5 deletions torch/csrc/utils/python_arg_parser.cpp
@@ -32,6 +32,28 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"std::string", ParameterType::STRING},
 };

+// Default arg name translations for compatibility with NumPy.
+//
+// Example:
+// ```python
+// t = torch.randn(10,10)
+// torch.sum(a=t, axis=0, keepdim=True)
+// ```
+//
+// A vector is necessary, because we might need to try multiple values.
+// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input tensor.
+// Rather than annotate each function separately with whether it should take "x" or "a",
+// just try both.
+//
+// TODO: Allow individual functions to specify non-default translations:
+// For example, `torch.pow` should translate "exponent" to "x2".
+static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
+  {"dim", {"axis"}},
+  {"keepdim", {"keepdims"}},
+  {"input", {"x", "a", "x1"}},
Member

I'm a bit ambivalent about the input and other aliases. I'm not sure people use them as kwargs often enough to make it worthwhile. There's also a long tail of names that NumPy uses instead of "input", including "A", "arr", "ary", "m", "a1", "array".

Contributor Author

> I'm a bit ambivalent...

I don't know how often people use them as kwargs. But I want to be as complete as possible.

> There's also a long tail...

Yes, and I've also seen "p", and who knows what else there is. I think we need to decide case by case, depending on how common each name is, whether to include it in the default common args here or to annotate individual functions with alternate arg names (I will introduce a facility for doing that in a future PR).

+  {"other", {"x2"}},
+};
+
 // TODO: remove this. This is a temporary list of functions that allow Python
 // numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
 // overloads and binding to the Tensor overload with a number of a different
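
The TODO above the translation table, and the author's reply in the review thread, both mention letting individual functions override the default aliases. A minimal Python sketch of how such layering might look is given here for illustration. The default table mirrors `numpy_compatibility_arg_names` from this diff; `FUNCTION_OVERRIDES` and `aliases_for` are hypothetical names invented for this sketch and are not part of this PR or of PyTorch.

```python
# Default aliases, copied from numpy_compatibility_arg_names in this diff.
DEFAULT_ALIASES = {
    "dim": ["axis"],
    "keepdim": ["keepdims"],
    "input": ["x", "a", "x1"],
    "other": ["x2"],
}

# Hypothetical per-function overrides (NOT in this PR). The single entry is
# the example named in the TODO: torch.pow's "exponent" maps to "x2".
FUNCTION_OVERRIDES = {
    "pow": {"exponent": ["x2"]},
}


def aliases_for(func_name, param_name):
    """Return the NumPy-style aliases to try for one parameter, in order."""
    overrides = FUNCTION_OVERRIDES.get(func_name, {})
    if param_name in overrides:
        # Assumption: an override replaces the defaults rather than extending them.
        return list(overrides[param_name])
    return list(DEFAULT_ALIASES.get(param_name, []))


print(aliases_for("sum", "dim"))       # ['axis']
print(aliases_for("pow", "exponent"))  # ['x2']
print(aliases_for("add", "out"))       # []
```

Whether an override should replace or extend the defaults is a design choice the PR does not settle; the sketch assumes replacement.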
@@ -94,11 +116,13 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
   } else {
     name = name_str;
   }
-#if PY_MAJOR_VERSION == 2
-  python_name = PyString_InternFromString(name.c_str());
-#else
-  python_name = PyUnicode_InternFromString(name.c_str());
-#endif
+  python_name = THPUtils_internString(name);
+  auto np_compat_it = numpy_compatibility_arg_names.find(name);
+  if (np_compat_it != numpy_compatibility_arg_names.end()) {
+    for (const auto& str: np_compat_it->second) {
+      numpy_python_names.push_back(THPUtils_internString(str));
+    }
+  }
 }

bool FunctionParameter::check(PyObject* obj) {
@@ -461,6 +485,12 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
       obj = PyTuple_GET_ITEM(args, arg_pos);
     } else if (kwargs) {
       obj = PyDict_GetItem(kwargs, param.python_name);
+      for (PyObject *numpy_name: param.numpy_python_names) {
+        if (obj) {
+          break;
+        }
+        obj = PyDict_GetItem(kwargs, numpy_name);
+      }
       is_kwd = true;
     }
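
The loop added in this hunk tries the canonical parameter name first and falls back to the NumPy aliases only while nothing has been found. A rough Python equivalent follows, for illustration only: `lookup_kwarg` is a hypothetical helper, not a PyTorch function, and it uses a sentinel object where the C++ code relies on `PyDict_GetItem` returning NULL for a missing key.

```python
_MISSING = object()  # stands in for the NULL that PyDict_GetItem returns


def lookup_kwarg(kwargs, name, numpy_names):
    """Return (matched_key, value), or (None, None) if nothing matched."""
    obj = kwargs.get(name, _MISSING)
    matched = name
    for numpy_name in numpy_names:
        if obj is not _MISSING:
            break  # the canonical name or an earlier alias already matched
        obj = kwargs.get(numpy_name, _MISSING)
        matched = numpy_name
    if obj is _MISSING:
        return None, None
    return matched, obj


print(lookup_kwarg({"axis": 0}, "dim", ["axis"]))            # ('axis', 0)
print(lookup_kwarg({"dim": 1, "axis": 0}, "dim", ["axis"]))  # ('dim', 1)
print(lookup_kwarg({"a": 5}, "input", ["x", "a", "x1"]))     # ('a', 5)
```

When both `dim` and `axis` are passed, the canonical name wins in this lookup and the alias is never read; how the parser treats the leftover key is outside the lines shown in this hunk.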

1 change: 1 addition & 0 deletions torch/csrc/utils/python_arg_parser.h
@@ -177,6 +177,7 @@ struct FunctionParameter {
   // having this as a raw PyObject * will presumably leak it, but these are only held by static objects
   // anyway, and Py_Finalize can already be called when this is destructed.
   PyObject *python_name;
+  at::SmallVector<PyObject *, 5> numpy_python_names;
   at::Scalar default_scalar;
   std::vector<int64_t> default_intlist;
   union {
8 changes: 8 additions & 0 deletions torch/csrc/utils/python_strings.h
@@ -57,3 +57,11 @@ inline PyObject* THPUtils_packString(const std::string& str) {
   return PyUnicode_FromStringAndSize(str.c_str(), str.size());
 #endif
 }
+
+inline PyObject* THPUtils_internString(const std::string& str) {
+#if PY_MAJOR_VERSION == 2
+  return PyString_InternFromString(str.c_str());
+#else
+  return PyUnicode_InternFromString(str.c_str());
+#endif
+}
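
`THPUtils_internString` wraps the C-level interning calls so that each parameter name and alias is stored as an interned string object. The effect of interning can be seen from Python 3 with `sys.intern`, the closest Python-level counterpart. In the sketch below, `intern_name` is a hypothetical wrapper written for illustration, not a PyTorch API.

```python
import sys


def intern_name(name):
    """Python-level analogue of THPUtils_internString (illustrative only)."""
    return sys.intern(name)


# Build the string at runtime so it starts out as its own object.
built = "".join(["keep", "dims"])
a = intern_name(built)
b = intern_name("keepdims")

print(a == b)  # True: same contents
print(a is b)  # True: interning returns one shared object for equal strings
```

Because interned names are shared objects, a keyword supplied by the caller and the name stored on the parameter can be the very same object, which keeps the repeated dictionary lookups in the parser cheap.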