Making ops c10-full: list of optional tensors (#49138)
Summary:
Pull Request resolved: #49138

See https://fb.quip.com/QRtJAin66lPN for details.

We need to model optional types explicitly, mostly for schema inference. So we cannot pass a `Tensor?[]` as `ArrayRef<Tensor>`, instead we need to pass it as an optional type. This PR changes it to `torch::List<c10::optional<Tensor>>`. It also makes the ops c10-full that were blocked by this.

## Backwards Compatibility

- This should not break the Python API because the representation in Python is the same and python_arg_parser just transforms the python list into a `List<optional<Tensor>>` instead of into a `List<Tensor>`.
- This should not break serialized models because there's some logic that allows loading a serialized `List<Tensor>` as `List<optional<Tensor>>`, see https://github.com/pytorch/pytorch/pull/49138/files#diff-9315f5dd045f47114c677174dcaa2f982721233eee1aa19068a42ff3ef775315R57
- This will break backwards compatibility for the C++ API. There is no implicit conversion from `ArrayRef<Tensor>` (the old argument type) to `List<optional<Tensor>>`. One common call pattern, `tensor.index({indices_tensor})` where `indices_tensor` is another `Tensor`, keeps working because the `{}` initializer_list constructor for `List<optional<Tensor>>` accepts `Tensor` elements that are implicitly converted to `optional<Tensor>`. But another common pattern, `tensor.index(indices_tensor)`, breaks: previously the `Tensor` was implicitly converted to `ArrayRef<Tensor>`, whereas now going `Tensor -> optional<Tensor> -> List<optional<Tensor>>` would require two implicit conversions, and C++ does not chain two implicit conversions. Those call sites have to be rewritten as `tensor.index({indices_tensor})`.

ghstack-source-id: 119269131

Test Plan:
## Benchmarks (C++ instruction counts):
### Forward
#### Script
```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4});
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```
#### Results
|Op call|before|after|delta|delta %|
|--------|------|-----|-----|-------|
|x[0] = 1                                                                |11566015 |11566015|0      |0.00% |
|x.index({0})                                                            |6807019  |6801019 |-6000  |-0.09%|
|x.index({0, 0})                                                         |13529019 |13557019|28000  |0.21% |
|x.index({0, 0, 0})                                                      |10677004 |10692004|15000  |0.14% |
|x.index({"..."})                                                        |5512015  |5506015 |-6000  |-0.11%|
|x.index({Slice(None, None, None)})                                      |6866016  |6936016 |70000  |1.02% |
|x.index({None})                                                         |8554015  |8548015 |-6000  |-0.07%|
|x.index({false})                                                        |22400000 |22744000|344000 |1.54% |
|x.index({true})                                                         |27624088 |27264393|-359695|-1.30%|
|x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})|123472000|123463306|-8694|-0.01%|

### Autograd
#### Script
```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4}, torch::requires_grad());
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```
Note: the script measures the **forward** path of an op call with autograd enabled (i.e. calls into VariableType). It does not measure the backward path.

#### Results
|Op call|before|after|delta|delta %|
|--------|------|-----|-----|-------|
|x.index({0})|14839019|14833019|-6000|-0.04%|
|x.index({0, 0})|28342019|28370019|28000|0.10%|
|x.index({0, 0, 0})|24434004|24449004|15000|0.06%|
|x.index({"..."})|12773015|12767015|-6000|-0.05%|
|x.index({Slice(None, None, None)})|14837016|14907016|70000|0.47%|
|x.index({None})|15926015|15920015|-6000|-0.04%|
|x.index({false})|36958000|37477000|519000|1.40%|
|x.index({true})|41971408|42426094|454686|1.08%|
|x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})|168184392|164545682|-3638710|-2.16%|

Reviewed By: bhosmer

Differential Revision: D25454632

fbshipit-source-id: 28ab0cffbbdbdff1c40b4130ca62ee72f981b76d
smessmer authored and facebook-github-bot committed Jan 4, 2021
1 parent e44b2b7 commit c7e9abb
Showing 45 changed files with 510 additions and 305 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
```diff
@@ -31,3 +31,4 @@
 #include <c10/util/Exception.h>
 #include <ATen/core/UnsafeFromTH.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
```
1 change: 1 addition & 0 deletions aten/src/ATen/ParallelOpenMP.cpp
```diff
@@ -1,4 +1,5 @@
 #include <ATen/Config.h>
+#include <ATen/core/jit_type.h>
 #if AT_PARALLEL_OPENMP
 #include <ATen/Parallel.h>
 
```
11 changes: 7 additions & 4 deletions aten/src/ATen/TensorIndexing.h
```diff
@@ -10,6 +10,8 @@
 // There is some back story, see https://github.com/pytorch/pytorch/issues/48684
 #include <ATen/NativeFunctions.h>
 
+#include <ATen/core/List.h>
+
 namespace at {
 namespace indexing {
 
@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
   (*dim_ptr)++;
 };
 
-static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
-  std::vector<Tensor> converted_inds(indices.size());
+static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
+  c10::List<c10::optional<Tensor>> converted_inds;
+  converted_inds.reserve(indices.size());
   for (size_t i = 0; i < indices.size(); ++i) {
     const auto &ind = indices[i];
     if (ind.defined()) {
-      converted_inds[i] = ind.to(ind.options().device(self.device()));
+      converted_inds.push_back(ind.to(ind.options().device(self.device())));
     } else {
-      converted_inds[i] = std::move(indices[i]);
+      converted_inds.push_back(std::move(indices[i]));
     }
   }
   return converted_inds;
```
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
```diff
@@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
   KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
   KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
-  KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
+  KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
   KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
   KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
 
```
2 changes: 1 addition & 1 deletion aten/src/ATen/core/List.h
```diff
@@ -243,7 +243,7 @@ class List final {
    * Example:
    *   List<int> a({2, 3, 4});
    */
-  explicit List(std::initializer_list<T> initial_values);
+  List(std::initializer_list<T> initial_values);
   explicit List(ArrayRef<T> initial_values);
 
   /**
```
16 changes: 14 additions & 2 deletions aten/src/ATen/core/List_inl.h
```diff
@@ -1,7 +1,7 @@
 #pragma once
 
+#include <ATen/core/jit_type_base.h>
 #include <ATen/core/ivalue.h>
-#include <ATen/core/jit_type.h>
 
 namespace c10 {
 
@@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
 namespace impl {
 template<class T>
 List<T> toTypedList(impl::GenericList list) {
-  TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
+  // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
+  // because upcasting would allow people to add types into the new list that would break the old list.
+  // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
+  // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
+  // without having to copy it. This is also used to provide backwards compatibility with some old models
+  // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
+  // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
+  // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
+  TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
+    || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
+    , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
   return List<T>(std::move(list.impl_));
 }
 
@@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
   impl_->elementType = std::move(t);
 }
 }
+
+#include <ATen/core/jit_type.h>
```
10 changes: 10 additions & 0 deletions aten/src/ATen/core/Variadic.h
```diff
@@ -6,6 +6,7 @@
 #include <utility>
 
 #include <c10/util/ArrayRef.h>
+#include <ATen/core/List.h>
 
 namespace at {
 
@@ -56,6 +57,15 @@ struct IterArgs {
     }
   }
 
+  template <typename T>
+  void operator()(const torch::List<T>& args) {
+    for (const auto& arg : args) {
+      self()(arg);
+      if (self().short_circuit())
+        return;
+    }
+  }
+
   // NB: we need to specify std::vector manually as C++ won't
   // do an implicit conversion to make a template deduction go through.
   template <typename T>
```
187 changes: 4 additions & 183 deletions aten/src/ATen/core/jit_type.h
@@ -1,10 +1,11 @@
#pragma once

#include <ATen/core/jit_type_base.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/ivalue.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>

@@ -17,197 +18,17 @@ struct ClassType;
namespace torch {
namespace jit {
struct CompilationUnit;
struct Function;
} // namespace jit
} // namespace torch

namespace c10 {

struct IValue;
struct FunctionSchema;
struct NamedType;
using OptNameList = c10::optional<std::vector<std::string>>;

#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};

TORCH_API const char* typeKindToString(TypeKind kind);

struct Type;
using TypePtr = std::shared_ptr<Type>;
using ConstTypePtr = std::shared_ptr<const Type>;

// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;

struct TORCH_API Type : std::enable_shared_from_this<Type> {
private:
TypeKind kind_;

protected:
Type(TypeKind kind) : kind_(kind) {}

virtual std::string annotation_str_impl(TypePrinter printer) const {
return str();
}

public:
virtual bool operator==(const Type& rhs) const = 0;

// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]

// if this returns false and the why_not stream is non-null, it contains
// additional details that describe why this is not a subtype of 'rhs'.
// This additional information should only contain details that are not obvious
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
virtual bool is_module() const;
bool isSubtypeOf(const TypePtr& rhs) const {
return isSubtypeOfExt(rhs, nullptr);
}

// How this type will appear in FunctionSchema declarations
virtual std::string str() const = 0;

// How this type will appear as if it were a type annotation in Python
// which is sometimes different than how it appears in declarations (e.g.
// int[] vs List[int])
//
// Takes a custom printer that users can pass in to customize the output of
// this method.
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
return *renamed;
}
}
return annotation_str_impl(printer);
}
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
// debuggers out.
return annotation_str(nullptr);
}

// Returns a human readable string that includes additional information like
// "type is inferred rather than explictly defined" to help construct more
// user-friendly messages.
virtual std::string repr_str() const {
return annotation_str();
}

TypeKind kind() const {
return kind_;
}

virtual bool requires_grad() const {
for (const auto& ct : containedTypes()) {
if (ct->requires_grad()) {
return true;
}
}
return false;
}

// Dynamically cast this object to the subclass indicated by the
// template variable, returning nullptr if the cast is invalid.
template <typename T>
std::shared_ptr<T> cast() {
if (T::Kind == kind()) {
return std::static_pointer_cast<T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<const T> cast() const {
if (T::Kind == kind()) {
return std::static_pointer_cast<const T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<T> expect() {
auto r = cast<T>();
AT_ASSERT(r);
return r;
}
template <typename T>
std::shared_ptr<const T> expect() const {
auto r = cast<const T>();
AT_ASSERT(r);
return r;
}
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
}
// list of types this type contains, e.g. for a List then element type of a
// list for a tuple, the types of the tuple elements
virtual at::ArrayRef<TypePtr> containedTypes() const {
return {};
}
// create a new version of this type, replacing its contained types with
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}
return createWithContained(std::move(contained_types));
}
// per-type constructor, you only need to override this if the
// containedTypes() is not empty
virtual TypePtr createWithContained(
std::vector<TypePtr> contained_types) const {
AT_ERROR(
"type with contained types did not overload createWithContained: ",
str());
}
};

struct AnyType;
using AnyTypePtr = std::shared_ptr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes
