Skip to content

Commit

Permalink
Initial torchbind prototype (#21098)
Browse files Browse the repository at this point in the history
Summary:
I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable.

Currently, something like this works:
```cpp
struct Foo {
  int x, y;
  Foo(): x(2), y(5){}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  void display() {
    cout<<"x: "<<x<<' '<<"y: "<<y<<endl;
  }
  int64_t add(int64_t z) {
    return (x+y)*z;
  }
};
static auto test = torch::jit::class_<Foo>("Foo")
                    .def(torch::jit::init<int64_t, int64_t>())
                    .def("display", &Foo::display)
                    .def("add", &Foo::add)
                    .def("combine", &Foo::combine);

```
with
```py
torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val.display()
    print(val.add(3))
```
results in
```
x: 5 y: 3
24
```

Current issues:
- [x] The python class created by torchscript doesn't interactly properly with the surrounding code.
```
torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    return val
```
- [x] Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe).
```cpp
  void combine(Foo x) {
```

- [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object).
```py
torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val2 = torch._C.Foo(100, 0)
    val.display()
    print(val.add(3))
```
- [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods).
- [x] `init` is a little bit different syntax than `pybind`. `.init<...>()` instead of `.def(py::init<>())`
- [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually.
- [ ] Currently, the conversion from Python into Torchscript doesn't work.
- [ ] Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
- [ ] We pass back into Python by value, currently. There's no way of passing by reference.
- [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature.

Somewhat blocked on #21177. We currently use some structures that will be refactored by his PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`.
Pull Request resolved: #21098

Differential Revision: D16634872

Pulled By: Chillee

fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
  • Loading branch information
Chillee authored and facebook-github-bot committed Aug 3, 2019
1 parent 4e6e11c commit f81db8a
Show file tree
Hide file tree
Showing 24 changed files with 607 additions and 20 deletions.
1 change: 1 addition & 0 deletions .jenkins/pytorch/macos-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ test_custom_script_ops() {

# Run tests Python-side and export a script module.
python test_custom_ops.py -v
python test_custom_classes.py -v
python model.py --export-script-module=model.pt
# Run tests C++-side and load the exported script module.
build/test_custom_ops ./model.pt
Expand Down
1 change: 1 addition & 0 deletions .jenkins/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ test_custom_script_ops() {
cp -a "$CUSTOM_OP_BUILD" build
# Run tests Python-side and export a script module.
python test_custom_ops.py -v
python test_custom_classes.py -v
python model.py --export-script-module=model.pt
# Run tests C++-side and load the exported script module.
build/test_custom_ops ./model.pt
Expand Down
2 changes: 2 additions & 0 deletions .jenkins/pytorch/win-test-helpers/test_custom_script_ops.bat
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat

git submodule update --init --recursive third_party/pybind11
cd test\custom_operator

:: Build the custom operator library.
Expand All @@ -23,6 +24,7 @@ popd

:: Run tests Python-side and export a script module.
python test_custom_ops.py -v
python test_custom_classes.py -v
python model.py --export-script-module="build/model.pt"
:: Run tests C++-side and load the exported script module.
cd build
Expand Down
13 changes: 13 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return printList(out, v.toTensorList(), "[", "]");
case IValue::Tag::Blob:
return out << *v.toBlob();
case IValue::Tag::Capsule:
return out << "Capsule";
case IValue::Tag::GenericList:
return printList(out, v.toGenericList(), "[", "]");
case IValue::Tag::Future:
Expand Down Expand Up @@ -170,4 +172,15 @@ std::vector<std::pair<IValue, IValue>> iterationOrder(const c10::Dict<IValue, IV
return ordered;
}

std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap() {
static std::unordered_map<std::string, c10::StrongTypePtr> tmap;
return tmap;
}

std::unordered_map<std::string, std::function<PyObject*(void*)>>&
getClassConverter() {
static std::unordered_map<std::string, std::function<PyObject*(void*)>>
classConverter;
return classConverter;
}
} // namespace c10
34 changes: 33 additions & 1 deletion aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
#include <ATen/core/blob.h>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/Tensor.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

namespace torch {
namespace jit {
class CustomClassHolder : public c10::intrusive_ptr_target {};
struct Function;
namespace script {
struct CompilationUnit;
Expand Down Expand Up @@ -49,8 +51,10 @@ struct Object;
_(GenericDict) \
_(Future) \
_(Device) \
_(Object) \
_(Uninitialized) \
_(Object)
_(Capsule) \


struct CAFFE2_API IValue final {
IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
Expand Down Expand Up @@ -148,6 +152,14 @@ struct CAFFE2_API IValue final {
c10::intrusive_ptr<caffe2::Blob> toBlob() &&;
c10::intrusive_ptr<caffe2::Blob> toBlob() const &;

// Capsule
IValue(intrusive_ptr<torch::jit::CustomClassHolder> blob);
bool isCapsule() const {
return Tag::Capsule == tag;
}
c10::intrusive_ptr<torch::jit::CustomClassHolder> toCapsule() &&;
c10::intrusive_ptr<torch::jit::CustomClassHolder> toCapsule() const &;

// Tuple
IValue(c10::intrusive_ptr<ivalue::Tuple> v);
bool isTuple() const { return Tag::Tuple == tag; }
Expand Down Expand Up @@ -564,6 +576,26 @@ struct StrongTypePtr {
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
std::shared_ptr<ClassType> type_;
};

TORCH_API std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap();
template<typename T>
c10::StrongTypePtr getCustomClassType() {
auto tmap = c10::getCustomClassTypeMap();
auto res = tmap.find(typeid(T).name());
if (res == tmap.end()) {
throw c10::Error("Can't find class id in custom class type map", "");
}
return res->second;
}

template<typename T>
inline bool isCustomClassRegistered() {
auto tmap = c10::getCustomClassTypeMap();
return tmap.find(typeid(T).name()) != tmap.end();
}

TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
getClassConverter();
}

#include <ATen/core/ivalue_inl.h>
95 changes: 95 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ struct IValue;
struct ClassType;
struct TupleType;

// For custom class __init__ registration, we need to pass in a function
// that looks like this: [](IValue x, args...)

// However, kernel_functor.h automatically sets the input types of the function
// by introspecting the types of the functor (which is IValue in this case).
// However, we need the type it binds to be Foo.

// Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
// which getTypePtr can recover the original class pointer.

template <typename TaggedCapsuleType>
struct tagged_capsule {
IValue ivalue;
};

template<class T, class NullType>
c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
auto t = c10::intrusive_ptr<T, NullType>::reclaim(static_cast<T*>(payload.as_intrusive_ptr));
Expand All @@ -38,6 +53,11 @@ c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
return p;
}

template<class T, class U>
intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
}

inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
return moveToIntrusivePtr<ivalue::Future>();
Expand Down Expand Up @@ -78,6 +98,14 @@ inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const & {
AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
return toIntrusivePtr<caffe2::Blob>();;
}
inline c10::intrusive_ptr<torch::jit::CustomClassHolder> IValue::toCapsule() && {
TORCH_INTERNAL_ASSERT(isCapsule());
return moveToIntrusivePtr<torch::jit::CustomClassHolder>();
}
inline c10::intrusive_ptr<torch::jit::CustomClassHolder> IValue::toCapsule() const & {
TORCH_INTERNAL_ASSERT(isCapsule());
return toIntrusivePtr<torch::jit::CustomClassHolder>();
}

namespace ivalue {

Expand Down Expand Up @@ -430,6 +458,23 @@ std::vector<Elem> generic_to(
return result;
}

template <typename T>
T generic_to(
IValue ivalue,
_fake_type<T>) {
using ElemType = typename std::remove_pointer<T>::type::element_type;
auto obj = ivalue.toObject();
auto capsule = obj->getSlot(0);
return c10::static_intrusive_pointer_cast<ElemType>(capsule.toCapsule());
}

template <typename T>
tagged_capsule<T> generic_to(
IValue ivalue,
_fake_type<tagged_capsule<T>>) {
return tagged_capsule<T>{ivalue};
}

template <typename Elem>
c10::List<Elem> generic_to(
IValue ivalue,
Expand Down Expand Up @@ -640,6 +685,10 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
: tag(Tag::Object), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(c10::intrusive_ptr<torch::jit::CustomClassHolder> v)
: tag(Tag::Capsule), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
: tag(Tag::Future), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
Expand Down Expand Up @@ -687,4 +736,50 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const {
}
}

namespace ivalue {
namespace detail {
// This code allows us to template on a function based on whether IValue has a
// constructor for it. Specifically, has_constructor<T>{} inherits from std::true_type if
// IValue(T) compiles, and inherits from std::false_type if IValue(T) doesn't.
// We use it for calling the IValue constructor for `from` if it exists, and otherwise
// attempt to use our custom class code.
template<class> struct type_sink { typedef void type; };
template<class T> using type_sink_t = typename type_sink<T>::type;
template<class T, class=void> struct has_constructor : std::false_type {}; \
template<class T> struct has_constructor<
T,
type_sink_t< decltype( IValue(std::declval<T>())) >
>: std::true_type {};

template <typename T>
IValue from_(T x, std::true_type) {
return IValue(x);
}
template <typename T>
IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
using inputType = c10::intrusive_ptr<T>;
if (!isCustomClassRegistered<inputType>()) {
throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
}
auto res = getCustomClassType<inputType>();
auto retObject = ivalue::Object::create(res->second, 1);
auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(x);

retObject->setSlot(0, IValue(objPtr));
auto resIVal = IValue(std::move(retObject));
return resIVal;
}
template <typename T>
IValue from_(T x, std::false_type) {
static_assert(guts::false_t<T>::value, "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
return IValue();
}
}

template <typename T>
IValue from(T x) {
return detail::from_(x, detail::has_constructor<T>{});
}

}
} // namespace c10
37 changes: 33 additions & 4 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <memory>
#include <type_traits>

struct ClassType;
namespace torch {
namespace jit {
struct Function;
Expand Down Expand Up @@ -48,7 +49,8 @@ using OptNameList = c10::optional<std::vector<std::string>>;
_(ProfiledTensorType) \
_(DeviceObjType) \
_(FunctionType) \
_(ClassType)
_(ClassType) \
_(CapsuleType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
Expand Down Expand Up @@ -1304,6 +1306,28 @@ struct VarType : public Type {
std::string name_;
};

struct CapsuleType;
using CapsuleTypePtr = std::shared_ptr<CapsuleType>;
// This type represents a Python Capsule
struct CAFFE2_API CapsuleType : public Type {
static CapsuleTypePtr create() {
return CapsuleTypePtr(new CapsuleType()); // NOLINT(modernize-make-shared)
}
DEFINE_IS_SUBCLASS(CapsuleType);
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
std::string str() const override {
return "Capsule";
}
static const TypeKind Kind = TypeKind::CapsuleType;
// global singleton
static CapsuleTypePtr get();
private:
CapsuleType()
: Type(TypeKind::CapsuleType) {}
};

CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
CAFFE2_API std::ostream& operator<<(std::ostream& out, const VaryingShape& t);
// what is the type, ignoring extra size/shape information?
Expand Down Expand Up @@ -1359,9 +1383,13 @@ CAFFE2_API c10::optional<TypePtr> unifyTypes(
namespace detail {
template <typename T>
struct getTypePtr_ final {
static_assert(
guts::false_t<T>::value,
"Type could not be converted to any of the known types.");
static TypePtr call() {
if (!isCustomClassRegistered<T>()) {
throw c10::Error("Type could not be converted to any of the known types.", "");
}
auto res = getCustomClassType<T>();
return std::dynamic_pointer_cast<Type>(res.type_);
}
};

template <>
Expand Down Expand Up @@ -1633,4 +1661,5 @@ struct CAFFE2_API ClassType : public NamedType {
// List of methods associated with this class.
std::vector<Function*> methods_;
};

} // namespace c10
13 changes: 10 additions & 3 deletions aten/src/ATen/core/op_registration/kernel_functor.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/ivalue.h>

namespace c10 {
/**
Expand Down Expand Up @@ -37,7 +38,10 @@ namespace detail {
>;

template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {
static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported input type.");
assert_is_valid_input_type() {
auto tmap = c10::getCustomClassTypeMap();
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as input argument");
}
};

template<class T, bool AllowDeprecatedTypes>
Expand Down Expand Up @@ -98,7 +102,10 @@ namespace detail {
};

template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_output_type {
static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported output type.");
assert_is_valid_output_type() {
auto tmap = getCustomClassTypeMap();
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as output");
}
};

template<class T, bool AllowDeprecatedTypes>
Expand Down Expand Up @@ -170,7 +177,7 @@ namespace detail {
template<class T, bool AllowDeprecatedTypes>
IValue return_to_ivalue(T&& v) {
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
return IValue(std::move(v));
return c10::ivalue::from(v);
}

template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices>
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ OptionalTypePtr OptionalType::ofTensor() {
static auto value = OptionalType::create(TensorType::get());
return value;
}
CapsuleTypePtr CapsuleType::get() {
static auto value = CapsuleType::create();
return value;
}
ListTypePtr ListType::ofTensors() {
static auto value = ListType::create(TensorType::get());
return value;
Expand Down
Loading

0 comments on commit f81db8a

Please sign in to comment.