Update on "Fix store based barrier to only use 'add'."
Certain store implementations don't work well when we use get() and add() on the same key. To avoid this issue, we only use add() in the store-based barrier. The buggy store implementations can't be fixed properly for legacy reasons.
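As a minimal illustration of that approach (a sketch, not the actual torch.distributed implementation, using a hypothetical in-process store in place of a real c10d store): the counter key is only ever touched through add(), and add(key, 0) doubles as an atomic read, so no get() is issued against the key.

#include <chrono>
#include <cstdint>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>

// Hypothetical stand-in for a key/value store that supports an atomic add().
class MiniStore {
 public:
  int64_t add(const std::string& key, int64_t delta) {
    std::lock_guard<std::mutex> lock(mu_);
    return counters_[key] += delta;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, int64_t> counters_;
};

// Each rank increments the shared counter once, then polls it with add(key, 0)
// until every rank has arrived or the timeout expires. No get() is ever issued.
void storeBasedBarrier(MiniStore& store, int world_size,
                       std::chrono::seconds timeout = std::chrono::seconds(300)) {
  const std::string key = "store_based_barrier_key";
  int64_t arrived = store.add(key, 1);
  const auto start = std::chrono::steady_clock::now();
  while (arrived < world_size) {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    arrived = store.add(key, 0);  // add(0) reads the counter without a get()
    if (std::chrono::steady_clock::now() - start > timeout) {
      throw std::runtime_error("Timed out waiting for all ranks in the store-based barrier");
    }
  }
}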

Differential Revision: [D25725386](https://our.internmc.facebook.com/intern/diff/D25725386/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25725386/)!

[ghstack-poisoned]
pritamdamania committed Jan 4, 2021
2 parents 9405faf + c7e9abb commit e0edfca
Showing 166 changed files with 2,607 additions and 787 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
@@ -31,3 +31,4 @@
#include <c10/util/Exception.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
1 change: 1 addition & 0 deletions aten/src/ATen/ParallelOpenMP.cpp
@@ -1,4 +1,5 @@
#include <ATen/Config.h>
#include <ATen/core/jit_type.h>
#if AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>

11 changes: 7 additions & 4 deletions aten/src/ATen/TensorIndexing.h
@@ -10,6 +10,8 @@
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>

#include <ATen/core/List.h>

namespace at {
namespace indexing {

@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
(*dim_ptr)++;
};

static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
std::vector<Tensor> converted_inds(indices.size());
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
c10::List<c10::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
const auto &ind = indices[i];
if (ind.defined()) {
converted_inds[i] = ind.to(ind.options().device(self.device()));
converted_inds.push_back(ind.to(ind.options().device(self.device())));
} else {
converted_inds[i] = std::move(indices[i]);
converted_inds.push_back(std::move(indices[i]));
}
}
return converted_inds;
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
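
With the index_put registration above, the indices argument is now a c10::List of optional Tensors, where an empty optional plays the role of Python's None for a dimension. A small usage sketch, assuming a program linked against libtorch:

#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <c10/util/Optional.h>
#include <cstdint>
#include <vector>

int main() {
  at::Tensor src = at::zeros({4, 3});
  at::Tensor rows = at::tensor(std::vector<int64_t>{0, 2});  // indices for dim 0

  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(rows);          // dim 0: an explicit index tensor
  indices.push_back(c10::nullopt);  // dim 1: "None", i.e. take the whole dimension

  at::Tensor values = at::ones({2, 3});
  // Matches the registered signature: (Tensor, List<optional<Tensor>>, Tensor, bool).
  at::Tensor out = at::index_put(src, indices, values, /*accumulate=*/false);
  return 0;
}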

2 changes: 1 addition & 1 deletion aten/src/ATen/core/List.h
@@ -243,7 +243,7 @@ class List final {
* Example:
* List<int> a({2, 3, 4});
*/
explicit List(std::initializer_list<T> initial_values);
List(std::initializer_list<T> initial_values);
explicit List(ArrayRef<T> initial_values);

/**
16 changes: 14 additions & 2 deletions aten/src/ATen/core/List_inl.h
@@ -1,7 +1,7 @@
#pragma once

#include <ATen/core/jit_type_base.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

namespace c10 {

@@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
namespace impl {
template<class T>
List<T> toTypedList(impl::GenericList list) {
TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
// If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
// because upcasting would allow people to add types into the new list that would break the old list.
// However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
// allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
// without having to copy it. This is also used to provide backwards compatibility with some old models
// that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
return List<T>(std::move(list.impl_));
}

@@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
impl_->elementType = std::move(t);
}
}

#include <ATen/core/jit_type.h>
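
The use_count()-gated upcast described in the toTypedList comment above can be illustrated with a self-contained sketch (deliberately not the real c10::List machinery): retyping a buffer to a wider element type is only allowed when the buffer is uniquely owned, because then no other list can observe elements of the new, wider type.

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct TypedBuffer {
  std::string element_type;   // e.g. "Tensor" or "Tensor?" (simplified type tags)
  std::vector<int> payload;   // element storage, simplified
};

// Returns the same buffer retyped to `target` when that is safe, otherwise throws.
// The buffer must be moved in, mirroring how toTypedList() receives the list by value:
// use_count() == 1 then really means "no other list aliases this storage".
std::shared_ptr<TypedBuffer> retype(std::shared_ptr<TypedBuffer> buf,
                                    const std::string& target,
                                    bool target_is_supertype) {
  const bool exact_match = buf->element_type == target;
  const bool uniquely_owned = buf.use_count() == 1;
  if (!exact_match && !(uniquely_owned && target_is_supertype)) {
    throw std::runtime_error("Tried to cast a List<" + buf->element_type +
                             "> to a List<" + target + ">. Types mismatch.");
  }
  buf->element_type = target;  // safe: same type, or sole owner widening the type
  return buf;
}

// Usage: auto widened = retype(std::move(only_ref), "Tensor?", /*target_is_supertype=*/true);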
10 changes: 10 additions & 0 deletions aten/src/ATen/core/Variadic.h
@@ -6,6 +6,7 @@
#include <utility>

#include <c10/util/ArrayRef.h>
#include <ATen/core/List.h>

namespace at {

@@ -56,6 +57,15 @@ struct IterArgs {
}
}

template <typename T>
void operator()(const torch::List<T>& args) {
for (const auto& arg : args) {
self()(arg);
if (self().short_circuit())
return;
}
}

// NB: we need to specify std::vector manually as C++ won't
// do an implicit conversion to make a template deduction go through.
template <typename T>
9 changes: 9 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -17,6 +17,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
_(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@@ -284,6 +285,9 @@ namespace c10 {
_(aten, zero_) \
_(aten, fill_) \
_(aten, masked_fill_) \
_(cuda, _set_device) \
_(cuda, set_stream) \
_(cuda, _current_device) \
_(aten, swapaxes) \
_(aten, swapaxes_) \
_(aten, swapdims) \
@@ -383,6 +387,7 @@
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
_(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@@ -453,6 +458,7 @@ struct TORCH_API Symbol {
// (and if it's not, you should add it to the built-ins list above.)
static Symbol attr(const std::string & s);
static Symbol aten(const std::string & s);
static Symbol cuda(const std::string & s);
static Symbol onnx(const std::string & s);
static Symbol prim(const std::string & s);
static Symbol user(const std::string & s);
@@ -463,6 +469,7 @@

bool is_attr() const;
bool is_aten() const;
bool is_cuda() const;
bool is_prim() const;
bool is_onnx() const;
bool is_user() const;
@@ -523,6 +530,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL)

inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); }
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@@ -531,6 +539,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
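
A short usage sketch for the new cuda symbol namespace, assuming a translation unit linked against ATen; it only exercises the accessors declared in this header:

#include <ATen/core/interned_strings.h>
#include <cassert>

int main() {
  // Build a symbol in the new cuda:: namespace and round-trip it through its qualified string.
  c10::Symbol set_device = c10::Symbol::cuda("_set_device");
  assert(set_device == c10::Symbol::fromQualString("cuda::_set_device"));
  assert(set_device.is_cuda() && !set_device.is_aten());
  return 0;
}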
187 changes: 4 additions & 183 deletions aten/src/ATen/core/jit_type.h
@@ -1,10 +1,11 @@
#pragma once

#include <ATen/core/jit_type_base.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/ivalue.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>

@@ -17,197 +18,17 @@ struct ClassType;
namespace torch {
namespace jit {
struct CompilationUnit;
struct Function;
} // namespace jit
} // namespace torch

namespace c10 {

struct IValue;
struct FunctionSchema;
struct NamedType;
using OptNameList = c10::optional<std::vector<std::string>>;

#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};

TORCH_API const char* typeKindToString(TypeKind kind);

struct Type;
using TypePtr = std::shared_ptr<Type>;
using ConstTypePtr = std::shared_ptr<const Type>;

// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;

struct TORCH_API Type : std::enable_shared_from_this<Type> {
private:
TypeKind kind_;

protected:
Type(TypeKind kind) : kind_(kind) {}

virtual std::string annotation_str_impl(TypePrinter printer) const {
return str();
}

public:
virtual bool operator==(const Type& rhs) const = 0;

// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]

// if this returns false and the why_not stream is non-null, it contains
// additional details that describe why this is not a subtype of 'rhs'.
// This additional information should only contain details that are not obvious
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
virtual bool is_module() const;
bool isSubtypeOf(const TypePtr& rhs) const {
return isSubtypeOfExt(rhs, nullptr);
}

// How this type will appear in FunctionSchema declarations
virtual std::string str() const = 0;

// How this type will appear as if it were a type annotation in Python
// which is sometimes different than how it appears in declarations (e.g.
// int[] vs List[int])
//
// Takes a custom printer that users can pass in to customize the output of
// this method.
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
return *renamed;
}
}
return annotation_str_impl(printer);
}
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
// debuggers out.
return annotation_str(nullptr);
}

// Returns a human readable string that includes additional information like
// "type is inferred rather than explictly defined" to help construct more
// user-friendly messages.
virtual std::string repr_str() const {
return annotation_str();
}

TypeKind kind() const {
return kind_;
}

virtual bool requires_grad() const {
for (const auto& ct : containedTypes()) {
if (ct->requires_grad()) {
return true;
}
}
return false;
}

// Dynamically cast this object to the subclass indicated by the
// template variable, returning nullptr if the cast is invalid.
template <typename T>
std::shared_ptr<T> cast() {
if (T::Kind == kind()) {
return std::static_pointer_cast<T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<const T> cast() const {
if (T::Kind == kind()) {
return std::static_pointer_cast<const T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<T> expect() {
auto r = cast<T>();
AT_ASSERT(r);
return r;
}
template <typename T>
std::shared_ptr<const T> expect() const {
auto r = cast<const T>();
AT_ASSERT(r);
return r;
}
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
}
// list of types this type contains, e.g. for a List then element type of a
// list for a tuple, the types of the tuple elements
virtual at::ArrayRef<TypePtr> containedTypes() const {
return {};
}
// create a new version of this type, replacing its contained types with
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}
return createWithContained(std::move(contained_types));
}
// per-type constructor, you only need to override this if the
// containedTypes() is not empty
virtual TypePtr createWithContained(
std::vector<TypePtr> contained_types) const {
AT_ERROR(
"type with contained types did not overload createWithContained: ",
str());
}
};

struct AnyType;
using AnyTypePtr = std::shared_ptr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes
