Skip to content

Commit

Permalink
Expose arbitrary cpp autograd functions to Python (#11082)
Browse files Browse the repository at this point in the history
Summary:
This is needed because the JIT declares some custom autograd functions.

colesbury
Pull Request resolved: #11082

Differential Revision: D9580456

Pulled By: apaszke

fbshipit-source-id: 6bf00c1188a20b2ee6ecf60e5a0099f8263ad55a
  • Loading branch information
apaszke authored and facebook-github-bot committed Aug 30, 2018
1 parent 93bd291 commit f0142fa
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
5 changes: 4 additions & 1 deletion torch/_tensor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,10 @@ def _str(self):
tensor_str = _tensor_str(self, indent, formatter, summarize)

if self.grad_fn is not None:
suffix += ', grad_fn=<{}>'.format(type(self.grad_fn).__name__)
name = type(self.grad_fn).__name__
if name == 'CppFunction':
name = self.grad_fn.name().rsplit('::', maxsplit=1)[-1]
suffix += ', grad_fn=<{}>'.format(name)
elif self.requires_grad:
suffix += ', requires_grad=True'

Expand Down
23 changes: 20 additions & 3 deletions torch/csrc/autograd/python_cpp_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "torch/csrc/autograd/python_hook.h"
#include "torch/csrc/autograd/python_anomaly_mode.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"

Expand Down Expand Up @@ -152,6 +153,10 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook)
return registerFunctionHook(fn, hook);
}

PyObject* THPCppFunction_name(PyObject* self) {
auto& fn = *((THPCppFunction*)self)->cdata;
return THPUtils_packString(fn.name());
}

static struct PyMethodDef default_methods[] = {
THP_FUNCTION_DEFAULT_METHODS,
Expand Down Expand Up @@ -184,8 +189,19 @@ PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,

static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;

struct DefaultFunctionType {
DefaultFunctionType() {
_initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
Py_INCREF(&type);
}

PyTypeObject type;
};

PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
{
static DefaultFunctionType default_type;

if (!cdata) {
Py_RETURN_NONE;
}
Expand All @@ -201,12 +217,13 @@ PyObject* functionToPyObject(std::shared_ptr<Function> cdata)
} else {
auto& fn = *cdata;
auto it = cpp_function_types.find(std::type_index(typeid(fn)));
PyTypeObject* type;
if (it == cpp_function_types.end()) {
return PyErr_Format(PyExc_TypeError,
"Don't know how to create Python object for %s", typeid(fn).name());
type = &default_type.type;
} else {
type = (PyTypeObject*)it->second.get();
}

PyTypeObject* type = (PyTypeObject*)it->second.get();
THPObjectPtr obj(type->tp_alloc(type, 0));
if (!obj) return nullptr;
THPCppFunction* f = (THPCppFunction*)obj.get();
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/autograd/python_cpp_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds)

#define THP_FUNCTION_DEFAULT_METHODS \
{(char*)"_register_hook_dict", (PyCFunction)THPCppFunction_register_hook_dict, METH_O, nullptr}, \
{(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, nullptr}
{(char*)"register_hook", (PyCFunction)THPCppFunction_register_hook, METH_O, nullptr}, \
{(char*)"name", (PyCFunction)THPCppFunction_name, METH_NOARGS, nullptr}

#define THP_FUNCTION_DEFAULT_PROPERTIES \
{(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \
Expand All @@ -44,6 +45,7 @@ PyObject* THPCppFunction_metadata(THPCppFunction *self, void *_unused);
PyObject* THPCppFunction_requires_grad(THPCppFunction* self);
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_name(PyObject* self);

PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name,
PyGetSetDef* function_properties, PyMethodDef* function_methods);
Expand Down

0 comments on commit f0142fa

Please sign in to comment.