From e7a3922c33f24699afc06a19d4d6ca5b0edfd0b3 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 8 Sep 2025 16:53:49 -0700 Subject: [PATCH] Add custom class tutorial back --- advanced_source/custom_class_pt2.rst | 4 +- advanced_source/custom_classes.rst | 231 ++++++++++++++++++ advanced_source/custom_classes/CMakeLists.txt | 15 ++ .../custom_class_project/CMakeLists.txt | 10 + .../custom_class_project/class.cpp | 132 ++++++++++ .../custom_class_project/custom_test.py | 53 ++++ .../custom_class_project/export_attr.py | 21 ++ .../custom_class_project/save.py | 18 ++ advanced_source/custom_classes/infer.cpp | 20 ++ advanced_source/custom_classes/run.sh | 21 ++ advanced_source/custom_classes/run2.sh | 13 + .../torch_script_custom_classes.rst | 6 - 12 files changed, 536 insertions(+), 8 deletions(-) create mode 100644 advanced_source/custom_classes.rst create mode 100644 advanced_source/custom_classes/CMakeLists.txt create mode 100644 advanced_source/custom_classes/custom_class_project/CMakeLists.txt create mode 100644 advanced_source/custom_classes/custom_class_project/class.cpp create mode 100644 advanced_source/custom_classes/custom_class_project/custom_test.py create mode 100644 advanced_source/custom_classes/custom_class_project/export_attr.py create mode 100644 advanced_source/custom_classes/custom_class_project/save.py create mode 100644 advanced_source/custom_classes/infer.cpp create mode 100755 advanced_source/custom_classes/run.sh create mode 100755 advanced_source/custom_classes/run2.sh delete mode 100644 advanced_source/torch_script_custom_classes.rst diff --git a/advanced_source/custom_class_pt2.rst b/advanced_source/custom_class_pt2.rst index 8579ff1567a..229a94f2ce9 100644 --- a/advanced_source/custom_class_pt2.rst +++ b/advanced_source/custom_class_pt2.rst @@ -3,7 +3,7 @@ Supporting Custom C++ Classes in torch.compile/torch.export This tutorial is a follow-on to the -:doc:`custom C++ classes ` tutorial, and +:doc:`custom C++ classes ` tutorial, and 
introduces additional steps that are needed to support custom C++ classes in torch.compile/torch.export. @@ -30,7 +30,7 @@ Concretely, there are a few steps: states returned by ``__obj_flatten__``. Here is a breakdown of the diff. Following the guide in -:doc:`Extending TorchScript with Custom C++ Classes `, +:doc:`Extending TorchScript with Custom C++ Classes `, we can create a thread-safe tensor queue and build it. .. code-block:: cpp diff --git a/advanced_source/custom_classes.rst b/advanced_source/custom_classes.rst new file mode 100644 index 00000000000..014bac2eebf --- /dev/null +++ b/advanced_source/custom_classes.rst @@ -0,0 +1,231 @@ +Extending PyTorch with Custom C++ Classes +=============================================== + + +This tutorial introduces an API for binding C++ classes into PyTorch. +The API is very similar to +`pybind11 `_, and most of the concepts will transfer +over if you're familiar with that system. + +Implementing and Binding the Class in C++ +----------------------------------------- + +For this tutorial, we are going to define a simple C++ class that maintains persistent +state in a member variable. + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp + :language: cpp + :start-after: BEGIN class + :end-before: END class + +There are several things to note: + +- ``torch/custom_class.h`` is the header you need to include to extend PyTorch + with your custom class. +- Notice that whenever we are working with instances of the custom + class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr`` + as a smart pointer like ``std::shared_ptr``, but the reference count is stored + directly in the object, as opposed to a separate metadata block (as is done in + ``std::shared_ptr``). ``torch::Tensor`` internally uses the same pointer type; + and custom classes have to also use this pointer type so that we can + consistently manage different object types. 
+- The second thing to notice is that the user-defined class must inherit from + ``torch::CustomClassHolder``. This ensures that the custom class has space to + store the reference count. + +Now let's take a look at how we will make this class visible to PyTorch, a process called +*binding* the class: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp + :language: cpp + :start-after: BEGIN binding + :end-before: END binding + :append: + ; + } + + + +Building the Example as a C++ Project With CMake +------------------------------------------------ + +Now, we're going to build the above C++ code with the `CMake +`_ build system. First, take all the C++ code +we've covered so far and place it in a file called ``class.cpp``. +Then, write a simple ``CMakeLists.txt`` file and place it in the +same directory. Here is what ``CMakeLists.txt`` should look like: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/CMakeLists.txt + :language: cmake + +Also, create a ``build`` directory. Your file tree should look like this:: + + custom_class_project/ + class.cpp + CMakeLists.txt + build/ + +Go ahead and invoke cmake and then make to build the project: + +.. code-block:: shell + + $ cd build + $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. 
+ -- The C compiler identification is GNU 7.3.1 + -- The CXX compiler identification is GNU 7.3.1 + -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc + -- Check for working C compiler: /opt/rh/devtoolset-7/root/usr/bin/cc -- works + -- Detecting C compiler ABI info + -- Detecting C compiler ABI info - done + -- Detecting C compile features + -- Detecting C compile features - done + -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ + -- Check for working CXX compiler: /opt/rh/devtoolset-7/root/usr/bin/c++ -- works + -- Detecting CXX compiler ABI info + -- Detecting CXX compiler ABI info - done + -- Detecting CXX compile features + -- Detecting CXX compile features - done + -- Looking for pthread.h + -- Looking for pthread.h - found + -- Looking for pthread_create + -- Looking for pthread_create - not found + -- Looking for pthread_create in pthreads + -- Looking for pthread_create in pthreads - not found + -- Looking for pthread_create in pthread + -- Looking for pthread_create in pthread - found + -- Found Threads: TRUE + -- Found torch: /torchbind_tutorial/libtorch/lib/libtorch.so + -- Configuring done + -- Generating done + -- Build files have been written to: /torchbind_tutorial/build + $ make -j + Scanning dependencies of target custom_class + [ 50%] Building CXX object CMakeFiles/custom_class.dir/class.cpp.o + [100%] Linking CXX shared library libcustom_class.so + [100%] Built target custom_class + +What you'll find is there is now (among other things) a dynamic library +file present in the build directory. On Linux, this is probably named +``libcustom_class.so``. So the file tree should look like:: + + custom_class_project/ + class.cpp + CMakeLists.txt + build/ + libcustom_class.so + +Using the C++ Class from Python +----------------------------------------------- + +Now that we have our class and its registration compiled into an ``.so`` file, +we can load that `.so` into Python and try it out. 
Here's a script that +demonstrates that: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/custom_test.py + :language: python + + +Defining Serialization/Deserialization Methods for Custom C++ Classes +--------------------------------------------------------------------- + +If you try to save a ``ScriptModule`` with a custom-bound C++ class as +an attribute, you'll get the following error: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/export_attr.py + :language: python + +.. code-block:: shell + + $ python export_attr.py + RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) + +This is because PyTorch cannot automatically figure out what information +to save from your C++ class. You must specify that manually. The way to do that +is to define ``__getstate__`` and ``__setstate__`` methods on the class using +the special ``def_pickle`` method on ``class_``. + +.. note:: + The semantics of ``__getstate__`` and ``__setstate__`` are + equivalent to that of the Python pickle module. You can + `read more `_ + about how we use these methods. + +Here is an example of the ``def_pickle`` call we can add to the registration of +``MyStackClass`` to include serialization methods: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp + :language: cpp + :start-after: BEGIN def_pickle + :end-before: END def_pickle + +.. note:: + We take a different approach from pybind11 in the pickle API. Whereas pybind11 + has a special function ``pybind11::pickle()`` which you pass into ``class_::def()``, + we have a separate method ``def_pickle`` for this purpose. This is because the + name ``torch::jit::pickle`` was already taken, and we didn't want to cause confusion. 
+ +Once we have defined the (de)serialization behavior in this way, our script can +now run successfully: + +.. code-block:: shell + + $ python ../export_attr.py + testing + +Defining Custom Operators that Take or Return Bound C++ Classes +--------------------------------------------------------------- + +Once you've defined a custom C++ class, you can also use that class +as an argument or return value of a custom operator (i.e. free functions). Suppose +you have the following free function: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp + :language: cpp + :start-after: BEGIN free_function + :end-before: END free_function + +You can register it by running the following code inside your ``TORCH_LIBRARY`` +block: + +.. literalinclude:: ../advanced_source/custom_classes/custom_class_project/class.cpp + :language: cpp + :start-after: BEGIN def_free + :end-before: END def_free + +Once this is done, you can use the op like the following example: + +.. code-block:: python + + class TryCustomOp(torch.nn.Module): + def __init__(self): + super(TryCustomOp, self).__init__() + self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"]) + + def forward(self): + return torch.ops.my_classes.manipulate_instance(self.f) + +.. note:: + + Registration of an operator that takes a C++ class as an argument requires that + the custom class has already been registered. You can enforce this by + making sure the custom class registration and your free function definitions + are in the same ``TORCH_LIBRARY`` block, and that the custom class + registration comes first. In the future, we may relax this requirement, + so that these can be registered in any order. + + +Conclusion +---------- + +This tutorial walked you through how to expose a C++ class to PyTorch, how to +register its methods, how to use that class from Python, and how to save and +load code using the class and run that code in a standalone C++ process. 
You +are now ready to extend your PyTorch models with C++ classes that interface +with third party C++ libraries or implement any other use case that requires +the lines between Python and C++ to blend smoothly. + +As always, if you run into any problems or have questions, you can use our +`forum `_ or `GitHub issues +`_ to get in touch. Also, our +`frequently asked questions (FAQ) page +`_ may have helpful information. diff --git a/advanced_source/custom_classes/CMakeLists.txt b/advanced_source/custom_classes/CMakeLists.txt new file mode 100644 index 00000000000..6a1eb3e87fa --- /dev/null +++ b/advanced_source/custom_classes/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +project(infer) + +find_package(Torch REQUIRED) + +add_subdirectory(custom_class_project) + +# Define our executable target +add_executable(infer infer.cpp) +set(CMAKE_CXX_STANDARD 14) +# Link against LibTorch +target_link_libraries(infer "${TORCH_LIBRARIES}") +# This is where we link in our libcustom_class code, making our +# custom class available in our binary. 
+target_link_libraries(infer -Wl,--no-as-needed custom_class) diff --git a/advanced_source/custom_classes/custom_class_project/CMakeLists.txt b/advanced_source/custom_classes/custom_class_project/CMakeLists.txt new file mode 100644 index 00000000000..bb3d41aa997 --- /dev/null +++ b/advanced_source/custom_classes/custom_class_project/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +project(custom_class) + +find_package(Torch REQUIRED) + +# Define our library target +add_library(custom_class SHARED class.cpp) +set(CMAKE_CXX_STANDARD 14) +# Link against LibTorch +target_link_libraries(custom_class "${TORCH_LIBRARIES}") diff --git a/advanced_source/custom_classes/custom_class_project/class.cpp b/advanced_source/custom_classes/custom_class_project/class.cpp new file mode 100644 index 00000000000..dc89a3ecb2e --- /dev/null +++ b/advanced_source/custom_classes/custom_class_project/class.cpp @@ -0,0 +1,132 @@ +// BEGIN class +// This header is all you need to do the C++ portions of this +// tutorial +#include +// This header is what defines the custom class registration +// behavior specifically. script.h already includes this, but +// we include it here so you know it exists in case you want +// to look at the API or implementation. 
+#include + +#include +#include + +template +struct MyStackClass : torch::CustomClassHolder { + std::vector stack_; + MyStackClass(std::vector init) : stack_(init.begin(), init.end()) {} + + void push(T x) { + stack_.push_back(x); + } + T pop() { + auto val = stack_.back(); + stack_.pop_back(); + return val; + } + + c10::intrusive_ptr clone() const { + return c10::make_intrusive(stack_); + } + + void merge(const c10::intrusive_ptr& c) { + for (auto& elem : c->stack_) { + push(elem); + } + } +}; +// END class + +// BEGIN free_function +c10::intrusive_ptr> manipulate_instance(const c10::intrusive_ptr>& instance) { + instance->pop(); + return instance; +} +// END free_function + +// BEGIN binding +// Notice a few things: +// - We pass the class to be registered as a template parameter to +// `torch::class_`. In this instance, we've passed the +// specialization of the MyStackClass class ``MyStackClass``. +// In general, you cannot register a non-specialized template +// class. For non-templated classes, you can just pass the +// class name directly as the template parameter. +// - The arguments passed to the constructor make up the "qualified name" +// of the class. In this case, the registered class will appear in +// Python and C++ as `torch.classes.my_classes.MyStackClass`. We call +// the first argument the "namespace" and the second argument the +// actual class name. +TORCH_LIBRARY(my_classes, m) { + m.class_>("MyStackClass") + // The following line registers the constructor of our MyStackClass + // class that takes a single `std::vector` argument, + // i.e. it exposes the C++ method `MyStackClass(std::vector init)`. + // Currently, we do not support registering overloaded + // constructors, so for now you can only `def()` one instance of + // `torch::init`. + .def(torch::init>()) + // The next line registers a stateless (i.e. no captures) C++ lambda + // function as a method. 
Note that a lambda function must take a + // `c10::intrusive_ptr` (or some const/ref version of that) + // as the first argument. Other arguments can be whatever you want. + .def("top", [](const c10::intrusive_ptr>& self) { + return self->stack_.back(); + }) + // The following four lines expose methods of the MyStackClass + // class as-is. `torch::class_` will automatically examine the + // argument and return types of the passed-in method pointers and + // expose these to Python and TorchScript accordingly. Finally, notice + // that we must take the *address* of the fully-qualified method name, + // i.e. use the unary `&` operator, due to C++ typing rules. + .def("push", &MyStackClass::push) + .def("pop", &MyStackClass::pop) + .def("clone", &MyStackClass::clone) + .def("merge", &MyStackClass::merge) +// END binding +#ifndef NO_PICKLE +// BEGIN def_pickle + // class_<>::def_pickle allows you to define the serialization + // and deserialization methods for your C++ class. + // Currently, we only support passing stateless lambda functions + // as arguments to def_pickle + .def_pickle( + // __getstate__ + // This function defines what data structure should be produced + // when we serialize an instance of this class. The function + // must take a single `self` argument, which is an intrusive_ptr + // to the instance of the object. The function can return + // any type that is supported as a return value of the TorchScript + // custom operator API. In this instance, we've chosen to return + // a std::vector as the salient data to preserve + // from the class. + [](const c10::intrusive_ptr>& self) + -> std::vector { + return self->stack_; + }, + // __setstate__ + // This function defines how to create a new instance of the C++ + // class when we are deserializing. The function must take a + // single argument of the same type as the return value of + // `__getstate__`. 
The function must return an intrusive_ptr + // to a new instance of the C++ class, initialized however + // you would like given the serialized state. + [](std::vector state) + -> c10::intrusive_ptr> { + // A convenient way to instantiate an object and get an + // intrusive_ptr to it is via `make_intrusive`. We use + // that here to allocate an instance of MyStackClass + // and call the single-argument std::vector + // constructor with the serialized state. + return c10::make_intrusive>(std::move(state)); + }); +// END def_pickle +#endif // NO_PICKLE + +// BEGIN def_free + m.def( + "manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y", + manipulate_instance + ); +// END def_free +} diff --git a/advanced_source/custom_classes/custom_class_project/custom_test.py b/advanced_source/custom_classes/custom_class_project/custom_test.py new file mode 100644 index 00000000000..1deda445310 --- /dev/null +++ b/advanced_source/custom_classes/custom_class_project/custom_test.py @@ -0,0 +1,53 @@ +import torch + +# `torch.classes.load_library()` allows you to pass the path to your .so file +# to load it in and make the custom C++ classes available to both Python and +# TorchScript +torch.classes.load_library("build/libcustom_class.so") +# You can query the loaded libraries like this: +print(torch.classes.loaded_libraries) +# prints {'/custom_class_project/build/libcustom_class.so'} + +# We can find and instantiate our custom C++ class in python by using the +# `torch.classes` namespace: +# +# This instantiation will invoke the MyStackClass(std::vector init) +# constructor we registered earlier +s = torch.classes.my_classes.MyStackClass(["foo", "bar"]) + +# We can call methods in Python +s.push("pushed") +assert s.pop() == "pushed" + +# Test custom operator +s.push("pushed") +torch.ops.my_classes.manipulate_instance(s) # acting as s.pop() +assert s.top() == "bar" + +# Returning and passing instances of custom 
classes works as you'd expect +s2 = s.clone() +s.merge(s2) +for expected in ["bar", "foo", "bar", "foo"]: + assert s.pop() == expected + +# We can also use the class in TorchScript +# For now, we need to assign the class's type to a local in order to +# annotate the type on the TorchScript function. This may change +# in the future. +MyStackClass = torch.classes.my_classes.MyStackClass + + +@torch.jit.script +def do_stacks(s: MyStackClass): # We can pass a custom class instance + # We can instantiate the class + s2 = torch.classes.my_classes.MyStackClass(["hi", "mom"]) + s2.merge(s) # We can call a method on the class + # We can also return instances of the class + # from TorchScript function/methods + return s2.clone(), s2.top() + + +stack, top = do_stacks(torch.classes.my_classes.MyStackClass(["wow"])) +assert top == "wow" +for expected in ["wow", "mom", "hi"]: + assert stack.pop() == expected diff --git a/advanced_source/custom_classes/custom_class_project/export_attr.py b/advanced_source/custom_classes/custom_class_project/export_attr.py new file mode 100644 index 00000000000..9999d5c8183 --- /dev/null +++ b/advanced_source/custom_classes/custom_class_project/export_attr.py @@ -0,0 +1,21 @@ +# export_attr.py +import torch + +torch.classes.load_library('build/libcustom_class.so') + + +class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"]) + + def forward(self, s: str) -> str: + return self.stack.pop() + s + + +scripted_foo = torch.jit.script(Foo()) + +scripted_foo.save('foo.pt') +loaded = torch.jit.load('foo.pt') + +print(loaded.stack.pop()) diff --git a/advanced_source/custom_classes/custom_class_project/save.py b/advanced_source/custom_classes/custom_class_project/save.py new file mode 100644 index 00000000000..8826f95da7c --- /dev/null +++ b/advanced_source/custom_classes/custom_class_project/save.py @@ -0,0 +1,18 @@ +import torch + 
+torch.classes.load_library('build/libcustom_class.so') + + +class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, s: str) -> str: + stack = torch.classes.my_classes.MyStackClass(["hi", "mom"]) + return stack.pop() + s + + +scripted_foo = torch.jit.script(Foo()) +print(scripted_foo.graph) + +scripted_foo.save('foo.pt') diff --git a/advanced_source/custom_classes/infer.cpp b/advanced_source/custom_classes/infer.cpp new file mode 100644 index 00000000000..1ca5b002383 --- /dev/null +++ b/advanced_source/custom_classes/infer.cpp @@ -0,0 +1,20 @@ +#include + +#include +#include + +int main(int argc, const char* argv[]) { + torch::jit::Module module; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + module = torch::jit::load("foo.pt"); + } + catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + return -1; + } + + std::vector inputs = {"foobarbaz"}; + auto output = module.forward(inputs).toString(); + std::cout << output->string() << std::endl; +} diff --git a/advanced_source/custom_classes/run.sh b/advanced_source/custom_classes/run.sh new file mode 100755 index 00000000000..52c59581309 --- /dev/null +++ b/advanced_source/custom_classes/run.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +set -ex + +rm -rf build +rm -rf custom_class_project/build + +pushd custom_class_project + mkdir build + (cd build && cmake CXXFLAGS="-DNO_PICKLE" -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..) + (cd build && make) + python custom_test.py + python save.py + ! python export_attr.py +popd + +mkdir build +(cd build && cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..) 
+(cd build && make) +mv custom_class_project/foo.pt build/foo.pt +(cd build && ./infer) diff --git a/advanced_source/custom_classes/run2.sh b/advanced_source/custom_classes/run2.sh new file mode 100755 index 00000000000..d4ef0101a83 --- /dev/null +++ b/advanced_source/custom_classes/run2.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -ex + +rm -rf build +rm -rf custom_class_project/build + +pushd custom_class_project + mkdir build + (cd build && cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..) + (cd build && make) + python export_attr.py +popd diff --git a/advanced_source/torch_script_custom_classes.rst b/advanced_source/torch_script_custom_classes.rst deleted file mode 100644 index 01bc497d38e..00000000000 --- a/advanced_source/torch_script_custom_classes.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. - TODO(gmagogsfm): Replace/delete this document by 2.9 release. https://github.com/pytorch/tutorials/issues/3456 - -.. warning:: - TorchScript is deprecated, please use - `torch.export `__ instead. \ No newline at end of file