Making ops c10-full: list of optional tensors #49138

Closed
wants to merge 50 commits
Commits (50)
1030266  [wip] Making ops c10-full: list of optional tensors  (smessmer, Dec 10, 2020)
f38b32c  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
038c206  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
ac4a088  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
9dff8b2  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
b4eb389  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
39666cb  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
a4647de  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
3851dfb  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
2e9d9e0  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
62ea7aa  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
b119bf0  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 10, 2020)
9306c6f  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
3b399a2  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
e498c38  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
10e7707  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
6abc660  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
f94de1d  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
c85436f  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
9705808  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
7e2453a  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
7619b58  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
11e2931  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 11, 2020)
adc572c  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 12, 2020)
802679b  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 15, 2020)
4e25799  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 15, 2020)
32c02e1  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 15, 2020)
01c0b3d  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 16, 2020)
365b6f4  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 18, 2020)
8f4d16f  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 18, 2020)
d7d4b48  Update on "[wip] Making ops c10-full: list of optional tensors"  (smessmer, Dec 18, 2020)
7df98eb  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 18, 2020)
0a07e13  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 18, 2020)
3aa958c  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 18, 2020)
1b12ff1  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 21, 2020)
c48901d  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 21, 2020)
539cfd1  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 21, 2020)
d385f8b  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 21, 2020)
e8bd651  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 21, 2020)
709f85f  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 22, 2020)
0bb686e  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 22, 2020)
2128b60  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 22, 2020)
9fe98ce  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 22, 2020)
4bd1339  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 22, 2020)
618c93c  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 22, 2020)
b3bbb6c  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 23, 2020)
735cedd  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 23, 2020)
4c493cd  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 23, 2020)
1ce09f1  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Dec 29, 2020)
11d8311  Update on "Making ops c10-full: list of optional tensors"  (smessmer, Jan 3, 2021)
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
@@ -31,3 +31,4 @@
#include <c10/util/Exception.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
1 change: 1 addition & 0 deletions aten/src/ATen/ParallelOpenMP.cpp
@@ -1,4 +1,5 @@
#include <ATen/Config.h>
#include <ATen/core/jit_type.h>
#if AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>

11 changes: 7 additions & 4 deletions aten/src/ATen/TensorIndexing.h
@@ -10,6 +10,8 @@
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>

#include <ATen/core/List.h>

namespace at {
namespace indexing {

@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
(*dim_ptr)++;
};

static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
std::vector<Tensor> converted_inds(indices.size());
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
c10::List<c10::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
const auto &ind = indices[i];
if (ind.defined()) {
converted_inds[i] = ind.to(ind.options().device(self.device()));
converted_inds.push_back(ind.to(ind.options().device(self.device())));
} else {
converted_inds[i] = std::move(indices[i]);
converted_inds.push_back(std::move(indices[i]));
}
}
return converted_inds;
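
typeConvertIndices now builds a c10::List<c10::optional<Tensor>>, the argument type the indexing ops take after this change. A minimal usage sketch, assuming a libtorch build that includes this PR (shapes and values are made up for illustration):

#include <torch/torch.h>

int main() {
  auto t = torch::arange(16, torch::kFloat).reshape({4, 4});

  // Build the index list: pick rows 0 and 2, keep every column.
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(c10::optional<at::Tensor>(torch::tensor({0, 2}, torch::kLong)));
  indices.push_back(c10::nullopt);  // an absent index leaves that dimension untouched

  // aten::index and aten::index_put accept the List<optional<Tensor>> directly.
  auto rows = at::index(t, indices);  // shape {2, 4}
  auto out = at::index_put(t, indices, torch::ones({2, 4}), /*accumulate=*/false);
  return 0;
}
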
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)

2 changes: 1 addition & 1 deletion aten/src/ATen/core/List.h
@@ -243,7 +243,7 @@ class List final {
* Example:
* List<int> a({2, 3, 4});
*/
explicit List(std::initializer_list<T> initial_values);
List(std::initializer_list<T> initial_values);
explicit List(ArrayRef<T> initial_values);

/**
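
The only change here is dropping explicit from the initializer_list constructor, so a braced list now converts implicitly to c10::List, matching the doc comment's own example. A small sketch of what this enables, assuming a libtorch build (the sum helper is hypothetical, for illustration only):

#include <ATen/core/List.h>

int64_t sum(const c10::List<int64_t>& values) {
  int64_t total = 0;
  for (int64_t v : values) {
    total += v;
  }
  return total;
}

int main() {
  c10::List<int64_t> a = {2, 3, 4};     // copy-list-initialization needs a non-explicit ctor
  return sum({5, 6, 7}) == 18 ? 0 : 1;  // braced list converts implicitly at the call site
}
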
16 changes: 14 additions & 2 deletions aten/src/ATen/core/List_inl.h
@@ -1,7 +1,7 @@
#pragma once

#include <ATen/core/jit_type_base.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

namespace c10 {

@@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
namespace impl {
template<class T>
List<T> toTypedList(impl::GenericList list) {
TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
// If there are other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
// because upcasting would allow people to add types into the new list that would break the old list.
// However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
// allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
// without having to copy it. This is also used to provide backwards compatibility with some old models
// that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
return List<T>(std::move(list.impl_));
}

@@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
impl_->elementType = std::move(t);
}
}

#include <ATen/core/jit_type.h>
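
The comment added to toTypedList explains why the upcast, e.g. List<Tensor> to List<optional<Tensor>>, is only allowed when the list has a single reference. A plain standard-library analogy of that argument (illustrative C++, not the c10 implementation):

#include <cassert>
#include <memory>
#include <optional>
#include <vector>

int main() {
  using IntList = std::vector<int>;
  using OptIntList = std::vector<std::optional<int>>;

  // Sole owner: treating the elements as "optional ints" is harmless because
  // no other holder still expects plain ints.
  auto unique_owner = std::make_shared<IntList>(IntList{1, 2, 3});
  assert(unique_owner.use_count() == 1);
  OptIntList upcast(unique_owner->begin(), unique_owner->end());  // safe

  // Two owners: if an upcast view shared the same storage, it could insert an
  // empty optional that the other owner has no way to represent. This is why
  // toTypedList only permits the upcast when list.use_count() == 1.
  auto second_owner = unique_owner;
  assert(second_owner.use_count() == 2);
  return 0;
}
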
10 changes: 10 additions & 0 deletions aten/src/ATen/core/Variadic.h
@@ -6,6 +6,7 @@
#include <utility>

#include <c10/util/ArrayRef.h>
#include <ATen/core/List.h>

namespace at {

@@ -56,6 +57,15 @@ struct IterArgs {
}
}

template <typename T>
void operator()(const torch::List<T>& args) {
for (const auto& arg : args) {
self()(arg);
if (self().short_circuit())
return;
}
}

// NB: we need to specify std::vector manually as C++ won't
// do an implicit conversion to make a template deduction go through.
template <typename T>
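
IterArgs is a CRTP visitor over heterogeneous argument packs; the new overload lets it walk into torch::List arguments element by element. A hypothetical visitor, not part of this PR, sketching how the overload is exercised (assumes a libtorch build that includes this change):

#include <ATen/ATen.h>
#include <ATen/core/Variadic.h>

// Counts the defined tensors in an argument pack, including those stored
// inside torch::List<optional<Tensor>> arguments.
struct CountDefinedTensors : at::IterArgs<CountDefinedTensors> {
  size_t count = 0;
  void operator()(const at::Tensor& t) {
    if (t.defined()) ++count;
  }
  void operator()(const c10::optional<at::Tensor>& t) {
    if (t.has_value() && t->defined()) ++count;
  }
  // Keep the ArrayRef / torch::List / std::vector overloads from the base.
  using at::IterArgs<CountDefinedTensors>::operator();
};

int main() {
  torch::List<c10::optional<at::Tensor>> indices;
  indices.push_back(c10::optional<at::Tensor>(at::ones({2})));
  indices.push_back(c10::nullopt);

  CountDefinedTensors visitor;
  visitor.apply(at::ones({3}), indices);  // visits the plain tensor, then each list element
  return visitor.count == 2 ? 0 : 1;
}
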
187 changes: 4 additions & 183 deletions aten/src/ATen/core/jit_type.h
@@ -1,10 +1,11 @@
#pragma once

#include <ATen/core/jit_type_base.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/ivalue.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>

@@ -17,197 +18,17 @@ struct ClassType;
namespace torch {
namespace jit {
struct CompilationUnit;
struct Function;
} // namespace jit
} // namespace torch

namespace c10 {

struct IValue;
struct FunctionSchema;
struct NamedType;
using OptNameList = c10::optional<std::vector<std::string>>;

#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};

TORCH_API const char* typeKindToString(TypeKind kind);

struct Type;
using TypePtr = std::shared_ptr<Type>;
using ConstTypePtr = std::shared_ptr<const Type>;

// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;

struct TORCH_API Type : std::enable_shared_from_this<Type> {
private:
TypeKind kind_;

protected:
Type(TypeKind kind) : kind_(kind) {}

virtual std::string annotation_str_impl(TypePrinter printer) const {
return str();
}

public:
virtual bool operator==(const Type& rhs) const = 0;

// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]

// if this returns false and the why_not stream is non-null, it contains
// additional details that describe why this is not a subtype of 'rhs'.
// This additional information should only contain details that are not obvious
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
virtual bool is_module() const;
bool isSubtypeOf(const TypePtr& rhs) const {
return isSubtypeOfExt(rhs, nullptr);
}

// How this type will appear in FunctionSchema declarations
virtual std::string str() const = 0;

// How this type will appear as if it were a type annotation in Python
// which is sometimes different than how it appears in declarations (e.g.
// int[] vs List[int])
//
// Takes a custom printer that users can pass in to customize the output of
// this method.
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
return *renamed;
}
}
return annotation_str_impl(printer);
}
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
// debuggers out.
return annotation_str(nullptr);
}

// Returns a human readable string that includes additional information like
// "type is inferred rather than explictly defined" to help construct more
// user-friendly messages.
virtual std::string repr_str() const {
return annotation_str();
}

TypeKind kind() const {
return kind_;
}

virtual bool requires_grad() const {
for (const auto& ct : containedTypes()) {
if (ct->requires_grad()) {
return true;
}
}
return false;
}

// Dynamically cast this object to the subclass indicated by the
// template variable, returning nullptr if the cast is invalid.
template <typename T>
std::shared_ptr<T> cast() {
if (T::Kind == kind()) {
return std::static_pointer_cast<T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<const T> cast() const {
if (T::Kind == kind()) {
return std::static_pointer_cast<const T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<T> expect() {
auto r = cast<T>();
AT_ASSERT(r);
return r;
}
template <typename T>
std::shared_ptr<const T> expect() const {
auto r = cast<const T>();
AT_ASSERT(r);
return r;
}
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
}
// list of types this type contains, e.g. for a List then element type of a
// list for a tuple, the types of the tuple elements
virtual at::ArrayRef<TypePtr> containedTypes() const {
return {};
}
// create a new version of this type, replacing its contained types with
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}
return createWithContained(std::move(contained_types));
}
// per-type constructor, you only need to override this if the
// containedTypes() is not empty
virtual TypePtr createWithContained(
std::vector<TypePtr> contained_types) const {
AT_ERROR(
"type with contained types did not overload createWithContained: ",
str());
}
};

struct AnyType;
using AnyTypePtr = std::shared_ptr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes
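
Most of what is deleted here, the TypeKind enum and the Type base class with its cast<T>() and expect<T>() helpers, appears to move behind the new jit_type_base.h include added at the top of the file rather than going away. A short sketch of the helpers the removed comments describe, assuming a libtorch build:

#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  // Build the type List[Tensor] and inspect it through the Type interface.
  c10::TypePtr t = c10::ListType::ofTensors();

  // cast<T>() returns nullptr when the dynamic kind does not match.
  if (auto list_type = t->cast<c10::ListType>()) {
    std::cout << list_type->getElementType()->str() << "\n";  // prints "Tensor"
  }
  std::cout << (t->cast<c10::TupleType>() == nullptr) << "\n";  // prints 1

  // expect<T>() asserts instead of returning nullptr.
  std::cout << t->expect<c10::ListType>()->str() << "\n";  // prints "Tensor[]"

  // The default subtyping rule: a type is a subtype of Optional of itself.
  auto opt = c10::OptionalType::create(c10::TensorType::get());
  std::cout << c10::TensorType::get()->isSubtypeOf(opt) << "\n";  // prints 1
  return 0;
}
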