diff --git a/aten/src/ATen/core/builtin_function.h b/aten/src/ATen/core/builtin_function.h index b4804cfebcbe..3d7f70d86877 100644 --- a/aten/src/ATen/core/builtin_function.h +++ b/aten/src/ATen/core/builtin_function.h @@ -10,13 +10,19 @@ struct BuiltinOpFunction : public Function { BuiltinOpFunction( c10::QualifiedName qualname, c10::FunctionSchema schema, - std::function callable) + std::function callable, + std::string doc_string = "") : name_(std::move(qualname)), callable_(std::move(callable)), - schema_(std::move(schema)) { + schema_(std::move(schema)), + doc_string_(std::move(doc_string)) { TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); } + const std::string& doc_string() const override { + return doc_string_; + } + bool isGraphFunction() const override { return false; } @@ -110,6 +116,8 @@ struct BuiltinOpFunction : public Function { std::function callable_; c10::FunctionSchema schema_; + + std::string doc_string_; }; } // namespace jit diff --git a/aten/src/ATen/core/function.h b/aten/src/ATen/core/function.h index 0cf658b0f701..8264bc57e8e8 100644 --- a/aten/src/ATen/core/function.h +++ b/aten/src/ATen/core/function.h @@ -25,6 +25,11 @@ TORCH_API void preoptimizeGraph(std::shared_ptr& graph); // execution of the function. Method is a wrapper around an // underlying Function that also provides a `self` object. struct TORCH_API Function { + virtual const std::string& doc_string() const { + static const std::string no_doc_string = ""; + return no_doc_string; + } + virtual bool isGraphFunction() const = 0; virtual void run(Stack& stack) = 0; diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index a1b21ee1ba21..06900064e266 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1989,7 +1989,8 @@ struct CAFFE2_API ClassType : public NamedType { static ClassTypePtr create( c10::optional qualifiedName, std::weak_ptr cu, - bool is_module = false); + bool is_module = false, + std::string doc_string = ""); bool operator==(const Type& rhs) const override { if (auto user_rhs = rhs.cast()) { @@ -2175,6 +2176,9 @@ struct CAFFE2_API ClassType : public NamedType { return constantNames_[slot]; } + const std::string& doc_string() const { + return doc_string_; + } IValue getConstant(const std::string& name) const; @@ -2271,7 +2275,8 @@ struct CAFFE2_API ClassType : public NamedType { ClassType( c10::optional name, std::weak_ptr cu, - bool is_module); + bool is_module, + std::string doc_string); std::string annotation_str_impl(TypePrinter printer = nullptr) const override { const auto& n = name().value(); @@ -2306,6 +2311,9 @@ struct CAFFE2_API ClassType : public NamedType { std::vector properties_; bool isModule_ = false; + + // Doc string of class. + std::string doc_string_ = ""; }; struct InterfaceType; diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 634d706091af..f5478d040060 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1211,19 +1211,21 @@ InterfaceType::~InterfaceType() = default; ClassTypePtr ClassType::create( c10::optional qualifiedName, std::weak_ptr cu, - bool is_module) { + bool is_module, + std::string doc_string) { return ClassTypePtr( - new ClassType(std::move(qualifiedName), std::move(cu), is_module)); + new ClassType(std::move(qualifiedName), std::move(cu), is_module, std::move(doc_string))); } ClassType::ClassType( c10::optional name, std::weak_ptr cu, - bool is_module = false) + bool is_module = false, + std::string doc_string = "") : NamedType(TypeKind::ClassType, std::move(name)), compilation_unit_(std::move(cu)), - isModule_(is_module) { -} + isModule_(is_module), + doc_string_(std::move(doc_string)) {} const std::vector& ClassType::methods() const { return methods_; diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index a96a3b4a5635..776df23e1737 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -44,5 +44,47 @@ TEST(CustomClassTest, TorchbindIValueAPI) { test_with_obj(new_stack_ivalue, "boo"); } +class TorchBindTestClass : public torch::jit::CustomClassHolder { + public: + std::string get() { + return "Hello, I am your test custom class"; + } +}; + +constexpr char class_doc_string[] = R"( + I am docstring for TorchBindTestClass + Args: + What is an argument? Oh never mind, I don't take any. + + Return: + How would I know? I am just a holder of some meaningless test methods. + )"; +constexpr char method_doc_string[] = + "I am docstring for TorchBindTestClass get_with_docstring method"; + +namespace { +static auto reg = + torch::class_( + "_TorchBindTest", + "_TorchBindTestClass", + class_doc_string) + .def("get", &TorchBindTestClass::get) + .def("get_with_docstring", &TorchBindTestClass::get, method_doc_string); + +} // namespace + +// Tests DocString is properly propagated when defining CustomClasses. +TEST(CustomClassTest, TestDocString) { + auto class_type = getCustomClass( + "__torch__.torch.classes._TorchBindTest._TorchBindTestClass"); + AT_ASSERT(class_type); + AT_ASSERT(class_type->doc_string() == class_doc_string); + + AT_ASSERT(class_type->getMethod("get").doc_string().empty()); + AT_ASSERT( + class_type->getMethod("get_with_docstring").doc_string() == + method_doc_string); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 49c85c8c3c7f..9809b854e6ac 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -28,7 +28,10 @@ void initPythonCustomClassBindings(PyObject* module) { auto m = py::handle(module).cast(); py::class_(m, "ScriptClass") - .def("__call__", &ScriptClass::__call__); + .def("__call__", &ScriptClass::__call__) + .def_property_readonly("__doc__", [](const ScriptClass& self) { + return self.class_type_.type_->expect()->doc_string(); + }); // This function returns a ScriptClass that wraps the constructor // of the given class, specified by the qualified name passed in. diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 74e0e75362a6..a99f7469ac65 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1187,9 +1187,13 @@ void initJitScriptBindings(PyObject* module) { "name", [](const StrongFunctionPtr& self) { return self.function_->name(); }) .def_property_readonly( - "qualified_name", [](const StrongFunctionPtr& self) { + "qualified_name", + [](const StrongFunctionPtr& self) { return self.function_->qualname().qualifiedName(); - }); + }) + .def_property_readonly("__doc__", [](const StrongFunctionPtr& self) { + return self.function_->doc_string(); + }); py::class_(m, "ScriptMethod", py::dynamic_attr()) .def( diff --git a/torch/custom_class.h b/torch/custom_class.h index 3805cfafc91a..571a584294db 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -58,14 +58,16 @@ class class_ { /// see this class exposed as in Python and TorchScript. For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript - explicit class_(const std::string& namespaceName, const std::string& className) { + explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") { detail::checkValidIdent(namespaceName, "Namespace name"); detail::checkValidIdent(className, "Class name"); qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className; classTypePtr = at::ClassType::create( c10::QualifiedName(qualClassName), - std::weak_ptr()); + std::weak_ptr(), + /*is_module=*/false, + std::move(doc_string)); classTypePtr->addAttribute("capsule", at::CapsuleType::get()); c10::getCustomClassTypeMap().insert( @@ -81,7 +83,7 @@ class class_ { /// `torch::init()` would register a two-argument constructor /// taking an `int` and a `std::string` as argument. template - class_& def(detail::types) { // Used in combination with + class_& def(detail::types, std::string doc_string = "") { // Used in combination with // torch::init<...>() auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); @@ -89,7 +91,7 @@ class class_ { object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); }; - defineMethod("__init__", std::move(func)); + defineMethod("__init__", std::move(func), std::move(doc_string)); return *this; } @@ -112,18 +114,18 @@ class class_ { /// // do something /// }) template - class_& def(std::string name, Func f) { + class_& def(std::string name, Func f, std::string doc_string = "") { auto wrapped_f = detail::wrap_func(std::move(f)); - defineMethod(std::move(name), std::move(wrapped_f)); + defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string)); return *this; } /// This is an unsafe method registration API added for adding custom JIT backend support via custom /// C++ classes. It is not for general purpose use. - class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema) { + class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto method = std::make_unique( - qualMethodName, std::move(schema), std::move(func)); + qualMethodName, std::move(schema), std::move(func), std::move(doc_string)); classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; @@ -228,7 +230,7 @@ class class_ { private: template - void defineMethod(std::string name, Func func) { + void defineMethod(std::string name, Func func, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); @@ -241,7 +243,7 @@ class class_ { detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( - qualMethodName, std::move(schema), std::move(wrapped_func)); + qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string)); // Register the method here to keep the Method alive. // ClassTypes do not hold ownership of their methods (normally it