Skip to content

Commit

Permalink
Support doc_string for TorchBind custom classes (#46576)
Browse files Browse the repository at this point in the history
Summary:
With this PR, users can optionally provide a "doc_string" to describe a class or its method. doc_string for TorchBind classes and methods are stored as `doc_string` properties in `Function` and `ScriptClass`. These `dos_string` properties are then exposed in Python layer via PyBind for doc generation.

Fixes #46047

Pull Request resolved: #46576

Reviewed By: wanchaol

Differential Revision: D24440636

Pulled By: gmagogsfm

fbshipit-source-id: bfa9b270a6c2d8bc769a88fad6be939cc6310412
  • Loading branch information
gmagogsfm authored and facebook-github-bot committed Oct 24, 2020
1 parent 7d4c1a5 commit f9b9430
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 22 deletions.
12 changes: 10 additions & 2 deletions aten/src/ATen/core/builtin_function.h
Expand Up @@ -10,13 +10,19 @@ struct BuiltinOpFunction : public Function {
BuiltinOpFunction(
c10::QualifiedName qualname,
c10::FunctionSchema schema,
std::function<void(Stack&)> callable)
std::function<void(Stack&)> callable,
std::string doc_string = "")
: name_(std::move(qualname)),
callable_(std::move(callable)),
schema_(std::move(schema)) {
schema_(std::move(schema)),
doc_string_(std::move(doc_string)) {
TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
}

const std::string& doc_string() const override {
return doc_string_;
}

bool isGraphFunction() const override {
return false;
}
Expand Down Expand Up @@ -110,6 +116,8 @@ struct BuiltinOpFunction : public Function {
std::function<void(Stack&)> callable_;

c10::FunctionSchema schema_;

std::string doc_string_;
};

} // namespace jit
Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/core/function.h
Expand Up @@ -25,6 +25,11 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
// execution of the function. Method is a wrapper around an
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
virtual const std::string& doc_string() const {
static const std::string no_doc_string = "";
return no_doc_string;
}

virtual bool isGraphFunction() const = 0;

virtual void run(Stack& stack) = 0;
Expand Down
12 changes: 10 additions & 2 deletions aten/src/ATen/core/jit_type.h
Expand Up @@ -1989,7 +1989,8 @@ struct CAFFE2_API ClassType : public NamedType {
static ClassTypePtr create(
c10::optional<QualifiedName> qualifiedName,
std::weak_ptr<CompilationUnit> cu,
bool is_module = false);
bool is_module = false,
std::string doc_string = "");

bool operator==(const Type& rhs) const override {
if (auto user_rhs = rhs.cast<ClassType>()) {
Expand Down Expand Up @@ -2175,6 +2176,9 @@ struct CAFFE2_API ClassType : public NamedType {
return constantNames_[slot];
}

const std::string& doc_string() const {
return doc_string_;
}

IValue getConstant(const std::string& name) const;

Expand Down Expand Up @@ -2271,7 +2275,8 @@ struct CAFFE2_API ClassType : public NamedType {
ClassType(
c10::optional<QualifiedName> name,
std::weak_ptr<CompilationUnit> cu,
bool is_module);
bool is_module,
std::string doc_string);

std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
const auto& n = name().value();
Expand Down Expand Up @@ -2306,6 +2311,9 @@ struct CAFFE2_API ClassType : public NamedType {
std::vector<Property> properties_;

bool isModule_ = false;

// Doc string of class.
std::string doc_string_ = "";
};

struct InterfaceType;
Expand Down
12 changes: 7 additions & 5 deletions aten/src/ATen/core/type.cpp
Expand Up @@ -1211,19 +1211,21 @@ InterfaceType::~InterfaceType() = default;
ClassTypePtr ClassType::create(
c10::optional<QualifiedName> qualifiedName,
std::weak_ptr<CompilationUnit> cu,
bool is_module) {
bool is_module,
std::string doc_string) {
return ClassTypePtr(
new ClassType(std::move(qualifiedName), std::move(cu), is_module));
new ClassType(std::move(qualifiedName), std::move(cu), is_module, std::move(doc_string)));
}

ClassType::ClassType(
c10::optional<QualifiedName> name,
std::weak_ptr<CompilationUnit> cu,
bool is_module = false)
bool is_module = false,
std::string doc_string = "")
: NamedType(TypeKind::ClassType, std::move(name)),
compilation_unit_(std::move(cu)),
isModule_(is_module) {
}
isModule_(is_module),
doc_string_(std::move(doc_string)) {}

const std::vector<torch::jit::Function*>& ClassType::methods() const {
return methods_;
Expand Down
42 changes: 42 additions & 0 deletions test/cpp/jit/test_custom_class.cpp
Expand Up @@ -44,5 +44,47 @@ TEST(CustomClassTest, TorchbindIValueAPI) {
test_with_obj(new_stack_ivalue, "boo");
}

class TorchBindTestClass : public torch::jit::CustomClassHolder {
public:
std::string get() {
return "Hello, I am your test custom class";
}
};

constexpr char class_doc_string[] = R"(
I am docstring for TorchBindTestClass
Args:
What is an argument? Oh never mind, I don't take any.
Return:
How would I know? I am just a holder of some meaningless test methods.
)";
constexpr char method_doc_string[] =
"I am docstring for TorchBindTestClass get_with_docstring method";

namespace {
static auto reg =
torch::class_<TorchBindTestClass>(
"_TorchBindTest",
"_TorchBindTestClass",
class_doc_string)
.def("get", &TorchBindTestClass::get)
.def("get_with_docstring", &TorchBindTestClass::get, method_doc_string);

} // namespace

// Tests DocString is properly propagated when defining CustomClasses.
TEST(CustomClassTest, TestDocString) {
auto class_type = getCustomClass(
"__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
AT_ASSERT(class_type);
AT_ASSERT(class_type->doc_string() == class_doc_string);

AT_ASSERT(class_type->getMethod("get").doc_string().empty());
AT_ASSERT(
class_type->getMethod("get_with_docstring").doc_string() ==
method_doc_string);
}

} // namespace jit
} // namespace torch
5 changes: 4 additions & 1 deletion torch/csrc/jit/python/python_custom_class.cpp
Expand Up @@ -28,7 +28,10 @@ void initPythonCustomClassBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

py::class_<ScriptClass>(m, "ScriptClass")
.def("__call__", &ScriptClass::__call__);
.def("__call__", &ScriptClass::__call__)
.def_property_readonly("__doc__", [](const ScriptClass& self) {
return self.class_type_.type_->expect<ClassType>()->doc_string();
});

// This function returns a ScriptClass that wraps the constructor
// of the given class, specified by the qualified name passed in.
Expand Down
8 changes: 6 additions & 2 deletions torch/csrc/jit/python/script_init.cpp
Expand Up @@ -1187,9 +1187,13 @@ void initJitScriptBindings(PyObject* module) {
"name",
[](const StrongFunctionPtr& self) { return self.function_->name(); })
.def_property_readonly(
"qualified_name", [](const StrongFunctionPtr& self) {
"qualified_name",
[](const StrongFunctionPtr& self) {
return self.function_->qualname().qualifiedName();
});
})
.def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
return self.function_->doc_string();
});

py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
.def(
Expand Down
22 changes: 12 additions & 10 deletions torch/custom_class.h
Expand Up @@ -58,14 +58,16 @@ class class_ {
/// see this class exposed as in Python and TorchScript. For example, if
/// you pass `foo` as the namespace name and `Bar` as the className, the
/// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
explicit class_(const std::string& namespaceName, const std::string& className) {
explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") {
detail::checkValidIdent(namespaceName, "Namespace name");
detail::checkValidIdent(className, "Class name");
qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className;

classTypePtr = at::ClassType::create(
c10::QualifiedName(qualClassName),
std::weak_ptr<jit::CompilationUnit>());
std::weak_ptr<jit::CompilationUnit>(),
/*is_module=*/false,
std::move(doc_string));
classTypePtr->addAttribute("capsule", at::CapsuleType::get());

c10::getCustomClassTypeMap().insert(
Expand All @@ -81,15 +83,15 @@ class class_ {
/// `torch::init<int, std::string>()` would register a two-argument constructor
/// taking an `int` and a `std::string` as argument.
template <typename... Types>
class_& def(detail::types<void, Types...>) { // Used in combination with
class_& def(detail::types<void, Types...>, std::string doc_string = "") { // Used in combination with
// torch::init<...>()
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
auto classObj = c10::make_intrusive<CurClass>(args...);
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
};

defineMethod("__init__", std::move(func));
defineMethod("__init__", std::move(func), std::move(doc_string));
return *this;
}

Expand All @@ -112,18 +114,18 @@ class class_ {
/// // do something
/// })
template <typename Func>
class_& def(std::string name, Func f) {
class_& def(std::string name, Func f, std::string doc_string = "") {
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
defineMethod(std::move(name), std::move(wrapped_f));
defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string));
return *this;
}

/// This is an unsafe method registration API added for adding custom JIT backend support via custom
/// C++ classes. It is not for general purpose use.
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema) {
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema, std::string doc_string = "") {
auto qualMethodName = qualClassName + "." + name;
auto method = std::make_unique<jit::BuiltinOpFunction>(
qualMethodName, std::move(schema), std::move(func));
qualMethodName, std::move(schema), std::move(func), std::move(doc_string));
classTypePtr->addMethod(method.get());
registerCustomClassMethod(std::move(method));
return *this;
Expand Down Expand Up @@ -228,7 +230,7 @@ class class_ {

private:
template <typename Func>
void defineMethod(std::string name, Func func) {
void defineMethod(std::string name, Func func, std::string doc_string = "") {
auto qualMethodName = qualClassName + "." + name;
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");

Expand All @@ -241,7 +243,7 @@ class class_ {
detail::BoxedProxy<RetType, Func>()(stack, func);
};
auto method = std::make_unique<jit::BuiltinOpFunction>(
qualMethodName, std::move(schema), std::move(wrapped_func));
qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string));

// Register the method here to keep the Method alive.
// ClassTypes do not hold ownership of their methods (normally it
Expand Down

0 comments on commit f9b9430

Please sign in to comment.