diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h index ae95ef43f21c..8d29a9204420 100644 --- a/aten/src/ATen/ATen.h +++ b/aten/src/ATen/ATen.h @@ -31,3 +31,4 @@ #include #include #include +#include diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index 07fc4e279557..261f6cdd46b5 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -1,4 +1,5 @@ #include +#include #if AT_PARALLEL_OPENMP #include diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index 3890662123a2..f6c3bbbe09cc 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -10,6 +10,8 @@ // There is some back story, see https://github.com/pytorch/pytorch/issues/48684 #include +#include + namespace at { namespace indexing { @@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector& (*dim_ptr)++; }; -static inline std::vector typeConvertIndices(const Tensor& self, std::vector&& indices) { - std::vector converted_inds(indices.size()); +static inline c10::List> typeConvertIndices(const Tensor& self, std::vector&& indices) { + c10::List> converted_inds; + converted_inds.reserve(indices.size()); for (size_t i = 0; i < indices.size(); ++i) { const auto &ind = indices[i]; if (ind.defined()) { - converted_inds[i] = ind.to(ind.options().device(self.device())); + converted_inds.push_back(ind.to(ind.options().device(self.device()))); } else { - converted_inds[i] = std::move(indices[i]); + converted_inds.push_back(std::move(indices[i])); } } return converted_inds; diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 8c82f965ef0f..dfb8e3ac0f32 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional), promote) KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote) KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote) - KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote) + KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List>&, const Tensor &, bool), promote) KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote) KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote) diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h index 40f733784fe5..f911722c51e1 100644 --- a/aten/src/ATen/core/List.h +++ b/aten/src/ATen/core/List.h @@ -243,7 +243,7 @@ class List final { * Example: * List a({2, 3, 4}); */ - explicit List(std::initializer_list initial_values); + List(std::initializer_list initial_values); explicit List(ArrayRef initial_values); /** diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index 3cbd7a310275..ab3ddae55770 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -1,7 +1,7 @@ #pragma once +#include #include -#include namespace c10 { @@ -50,7 +50,17 @@ List::List(TypePtr elementType) namespace impl { template List toTypedList(impl::GenericList list) { - TORCH_INTERNAL_ASSERT(*getTypePtr() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); + // If there's other instances of the list (i.e. 
list.use_count() > 1), then we have to be invariant + // because upcasting would allow people to add types into the new list that would break the old list. + // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can + // allow upcasting. This can be a perf improvement since we can cast List to List> + // without having to copy it. This is also used to provide backwards compatibility with some old models + // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_ + // as List before we changed that argument to be List>. When deserializing, we + // have list.use_count() == 1 and can deserialize the List directly as List>. + TORCH_CHECK(*list.impl_->elementType == *getTypePtr() + || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr())) + , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); return List(std::move(list.impl_)); } @@ -312,3 +322,5 @@ void List::unsafeSetElementType(TypePtr t) { impl_->elementType = std::move(t); } } + +#include diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h index b49d94bba1c8..d33f3d575177 100644 --- a/aten/src/ATen/core/Variadic.h +++ b/aten/src/ATen/core/Variadic.h @@ -6,6 +6,7 @@ #include #include +#include namespace at { @@ -56,6 +57,15 @@ struct IterArgs { } } + template + void operator()(const torch::List& args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + // NB: we need to specify std::vector manually as C++ won't // do an implicit conversion to make a template deduction go through. template diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 8065300f0b32..f99dc3c07058 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -17,6 +17,7 @@ namespace c10 { #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ _(namespaces, aten) \ + _(namespaces, cuda) \ _(namespaces, onnx) \ _(namespaces, attr) \ _(namespaces, scope) \ @@ -284,6 +285,9 @@ namespace c10 { _(aten, zero_) \ _(aten, fill_) \ _(aten, masked_fill_) \ + _(cuda, _set_device) \ + _(cuda, set_stream) \ + _(cuda, _current_device) \ _(aten, swapaxes) \ _(aten, swapaxes_) \ _(aten, swapdims) \ @@ -383,6 +387,7 @@ namespace c10 { #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ _(namespaces, aten) \ + _(namespaces, cuda) \ _(namespaces, onnx) \ _(namespaces, attr) \ _(namespaces, scope) \ @@ -453,6 +458,7 @@ struct TORCH_API Symbol { // (and if it's not, you should add it to the built-ins list above.) 
static Symbol attr(const std::string & s); static Symbol aten(const std::string & s); + static Symbol cuda(const std::string & s); static Symbol onnx(const std::string & s); static Symbol prim(const std::string & s); static Symbol user(const std::string & s); @@ -463,6 +469,7 @@ struct TORCH_API Symbol { bool is_attr() const; bool is_aten() const; + bool is_cuda() const; bool is_prim() const; bool is_onnx() const; bool is_user() const; @@ -523,6 +530,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL) inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); } inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); } +inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); } inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); } inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); } inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); } @@ -531,6 +539,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); } inline bool Symbol::is_attr() const { return ns() == namespaces::attr; } inline bool Symbol::is_aten() const { return ns() == namespaces::aten; } +inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; } inline bool Symbol::is_prim() const { return ns() == namespaces::prim; } inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; } inline bool Symbol::is_user() const { return ns() == namespaces::user; } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index f6902cd4beb6..a3ae813616e0 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1,10 +1,11 @@ #pragma once +#include #include #include #include -#include #include +#include #include #include @@ -17,197 +18,17 @@ struct ClassType; namespace torch { namespace jit { struct CompilationUnit; +struct Function; } // namespace jit } // namespace torch namespace c10 { +struct IValue; struct FunctionSchema; struct NamedType; using OptNameList = c10::optional>; -#define C10_FORALL_TYPES(_) \ - _(AnyType) \ - _(EnumType) \ - _(AnyEnumType) \ - _(TensorType) \ - _(StorageType) \ - _(TupleType) \ - _(ListType) \ - _(DictType) \ - _(NumberType) \ - _(FloatType) \ - _(FutureType) \ - _(RRefType) \ - _(IntType) \ - _(NoneType) \ - _(StringType) \ - _(GeneratorType) \ - _(QuantizerType) \ - _(BoolType) \ - _(OptionalType) \ - _(VarType) \ - _(DeviceObjType) \ - _(StreamObjType) \ - _(FunctionType) \ - _(ClassType) \ - _(PyObjectType) \ - _(CapsuleType) \ - _(InterfaceType) \ - _(QSchemeType) \ - _(LayoutType) \ - _(ScalarTypeType) \ - _(AnyListType) \ - _(AnyTupleType) \ - _(AnyClassType) - -enum class TypeKind { -#define DEFINE_TYPE(T) T, - C10_FORALL_TYPES(DEFINE_TYPE) -#undef DEFINE_TYPE -}; - -TORCH_API const char* typeKindToString(TypeKind kind); - -struct Type; -using TypePtr = std::shared_ptr; -using ConstTypePtr = std::shared_ptr; - -// Use this to customize how a Type is printed using `annotation_str()`. If -// c10::nullopt is returned, `annotation_str()` falls through to its default -// implementation. 
-using TypePrinter = - std::function(const ConstTypePtr&)>; - -struct TORCH_API Type : std::enable_shared_from_this { - private: - TypeKind kind_; - - protected: - Type(TypeKind kind) : kind_(kind) {} - - virtual std::string annotation_str_impl(TypePrinter printer) const { - return str(); - } - - public: - virtual bool operator==(const Type& rhs) const = 0; - - // subtyping relation. By default, we return true for the case - // when the type is exactly equal or if this <: T where rhs = Optional[T] - - // if this returns false and the why_not stream is non-null, it contains - // additional details that describe why this is not a subtype of 'rhs'. - // This additional information should only contain details that are not obvious - // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false - // but not clear why `Foo <: InterfaceBar` might be false. - virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const; - virtual bool is_module() const; - bool isSubtypeOf(const TypePtr& rhs) const { - return isSubtypeOfExt(rhs, nullptr); - } - - // How this type will appear in FunctionSchema declarations - virtual std::string str() const = 0; - - // How this type will appear as if it were a type annotation in Python - // which is sometimes different than how it appears in declarations (e.g. - // int[] vs List[int]) - // - // Takes a custom printer that users can pass in to customize the output of - // this method. - std::string annotation_str(TypePrinter printer) const { - if (printer) { - // the printer can return nullopt to fall through to the default impl - if (auto renamed = printer(shared_from_this())) { - return *renamed; - } - } - return annotation_str_impl(printer); - } - std::string annotation_str() const { - // Overload instead of define a default value for `printer` to help - // debuggers out. - return annotation_str(nullptr); - } - - // Returns a human readable string that includes additional information like - // "type is inferred rather than explictly defined" to help construct more - // user-friendly messages. - virtual std::string repr_str() const { - return annotation_str(); - } - - TypeKind kind() const { - return kind_; - } - - virtual bool requires_grad() const { - for (const auto& ct : containedTypes()) { - if (ct->requires_grad()) { - return true; - } - } - return false; - } - - // Dynamically cast this object to the subclass indicated by the - // template variable, returning nullptr if the cast is invalid. - template - std::shared_ptr cast() { - if (T::Kind == kind()) { - return std::static_pointer_cast(shared_from_this()); - } - return nullptr; - } - template - std::shared_ptr cast() const { - if (T::Kind == kind()) { - return std::static_pointer_cast(shared_from_this()); - } - return nullptr; - } - template - std::shared_ptr expect() { - auto r = cast(); - AT_ASSERT(r); - return r; - } - template - std::shared_ptr expect() const { - auto r = cast(); - AT_ASSERT(r); - return r; - } - virtual ~Type() = default; - virtual bool hasFreeVariables() const { - return false; - } - // list of types this type contains, e.g. 
for a List then element type of a - // list for a tuple, the types of the tuple elements - virtual at::ArrayRef containedTypes() const { - return {}; - } - // create a new version of this type, replacing its contained types with - // contained_types - TypePtr withContained(std::vector contained_types) { - auto current_contained = containedTypes(); - AT_ASSERT(current_contained.size() == contained_types.size()); - if (current_contained.equals(contained_types)) { - return shared_from_this(); - } - return createWithContained(std::move(contained_types)); - } - // per-type constructor, you only need to override this if the - // containedTypes() is not empty - virtual TypePtr createWithContained( - std::vector contained_types) const { - AT_ERROR( - "type with contained types did not overload createWithContained: ", - str()); - } -}; - struct AnyType; using AnyTypePtr = std::shared_ptr; // Any is the top of the type hierarchy, all other types are subtypes diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h new file mode 100644 index 000000000000..37da9ad7ef8d --- /dev/null +++ b/aten/src/ATen/core/jit_type_base.h @@ -0,0 +1,195 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +#define C10_FORALL_TYPES(_) \ + _(AnyType) \ + _(EnumType) \ + _(AnyEnumType) \ + _(TensorType) \ + _(StorageType) \ + _(TupleType) \ + _(ListType) \ + _(DictType) \ + _(NumberType) \ + _(FloatType) \ + _(FutureType) \ + _(RRefType) \ + _(IntType) \ + _(NoneType) \ + _(StringType) \ + _(GeneratorType) \ + _(QuantizerType) \ + _(BoolType) \ + _(OptionalType) \ + _(VarType) \ + _(DeviceObjType) \ + _(StreamObjType) \ + _(FunctionType) \ + _(ClassType) \ + _(PyObjectType) \ + _(CapsuleType) \ + _(InterfaceType) \ + _(QSchemeType) \ + _(LayoutType) \ + _(ScalarTypeType) \ + _(AnyListType) \ + _(AnyTupleType) \ + _(AnyClassType) + +enum class TypeKind { +#define DEFINE_TYPE(T) T, + C10_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +TORCH_API const char* typeKindToString(TypeKind kind); + +struct Type; +using TypePtr = std::shared_ptr; +using ConstTypePtr = std::shared_ptr; + +// Use this to customize how a Type is printed using `annotation_str()`. If +// c10::nullopt is returned, `annotation_str()` falls through to its default +// implementation. +using TypePrinter = + std::function(const ConstTypePtr&)>; + +struct TORCH_API Type : std::enable_shared_from_this { + private: + TypeKind kind_; + + protected: + Type(TypeKind kind) : kind_(kind) {} + + virtual std::string annotation_str_impl(TypePrinter printer) const { + return str(); + } + + public: + virtual bool operator==(const Type& rhs) const = 0; + + // subtyping relation. By default, we return true for the case + // when the type is exactly equal or if this <: T where rhs = Optional[T] + + // if this returns false and the why_not stream is non-null, it contains + // additional details that describe why this is not a subtype of 'rhs'. + // This additional information should only contain details that are not obvious + // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false + // but not clear why `Foo <: InterfaceBar` might be false. 
+ virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const; + virtual bool is_module() const; + bool isSubtypeOf(const TypePtr& rhs) const { + return isSubtypeOfExt(rhs, nullptr); + } + + // How this type will appear in FunctionSchema declarations + virtual std::string str() const = 0; + + // How this type will appear as if it were a type annotation in Python + // which is sometimes different than how it appears in declarations (e.g. + // int[] vs List[int]) + // + // Takes a custom printer that users can pass in to customize the output of + // this method. + std::string annotation_str(TypePrinter printer) const { + if (printer) { + // the printer can return nullopt to fall through to the default impl + if (auto renamed = printer(shared_from_this())) { + return *renamed; + } + } + return annotation_str_impl(printer); + } + std::string annotation_str() const { + // Overload instead of define a default value for `printer` to help + // debuggers out. + return annotation_str(nullptr); + } + + // Returns a human readable string that includes additional information like + // "type is inferred rather than explictly defined" to help construct more + // user-friendly messages. + virtual std::string repr_str() const { + return annotation_str(); + } + + TypeKind kind() const { + return kind_; + } + + virtual bool requires_grad() const { + for (const auto& ct : containedTypes()) { + if (ct->requires_grad()) { + return true; + } + } + return false; + } + + // Dynamically cast this object to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid. + template + std::shared_ptr cast() { + if (T::Kind == kind()) { + return std::static_pointer_cast(shared_from_this()); + } + return nullptr; + } + template + std::shared_ptr cast() const { + if (T::Kind == kind()) { + return std::static_pointer_cast(shared_from_this()); + } + return nullptr; + } + template + std::shared_ptr expect() { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + std::shared_ptr expect() const { + auto r = cast(); + AT_ASSERT(r); + return r; + } + virtual ~Type() = default; + virtual bool hasFreeVariables() const { + return false; + } + // list of types this type contains, e.g. 
for a List then element type of a + // list for a tuple, the types of the tuple elements + virtual at::ArrayRef containedTypes() const { + return {}; + } + // create a new version of this type, replacing its contained types with + // contained_types + TypePtr withContained(std::vector contained_types) { + auto current_contained = containedTypes(); + AT_ASSERT(current_contained.size() == contained_types.size()); + if (current_contained.equals(contained_types)) { + return shared_from_this(); + } + return createWithContained(std::move(contained_types)); + } + // per-type constructor, you only need to override this if the + // containedTypes() is not empty + virtual TypePtr createWithContained( + std::vector contained_types) const { + AT_ERROR( + "type with contained types did not overload createWithContained: ", + str()); + } +}; + +} diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index bf74e8b356c7..a4854e1ced4d 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -68,7 +68,7 @@ Tensor embedding_sparse_backward( Tensor indices = indices_; Tensor grad = grad_; if (padding_idx != -1) { - auto c = indices != padding_idx; + torch::List> c({indices != padding_idx}); indices = indices.index(c); grad = grad.index(c); } diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h index 94d61b02dd0b..92f6957f25ad 100644 --- a/aten/src/ATen/native/IndexingUtils.h +++ b/aten/src/ATen/native/IndexingUtils.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include @@ -15,40 +16,45 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, } -static std::vector expandTensors(const Tensor & self, TensorList indices) { +static std::vector expandTensors(const Tensor & self, const torch::List>& indices) { // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors std::vector result; - for (const auto & index : indices) { - if (index.scalar_type() == kByte || index.scalar_type() == kBool) { - if (index.scalar_type() == kByte) { - TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \ - " please use a dtype torch.bool instead."); - } - // The sizes of the ByteTensor mask or bool tensor must match the sizes of the - // corresponding dimensions in self - for (int64_t j = 0; j < index.dim(); j++) { - int64_t srcIdx = result.size() + j; - if (index.size(j) != self.size(srcIdx)) { - invalid_mask(self, srcIdx, index, j); + for (c10::optional index_opt : indices) { + if (!index_opt.has_value()) { + result.emplace_back(); + } else { + Tensor index = std::move(*index_opt); + if (index.scalar_type() == kByte || index.scalar_type() == kBool) { + if (index.scalar_type() == kByte) { + TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \ + " please use a dtype torch.bool instead."); } + // The sizes of the ByteTensor mask or bool tensor must match the sizes of the + // corresponding dimensions in self + for (int64_t j = 0; j < index.dim(); j++) { + int64_t srcIdx = result.size() + j; + if (index.size(j) != self.size(srcIdx)) { + invalid_mask(self, srcIdx, index, j); + } + } + // Replace with nonzeros + auto nonzero = index.nonzero(); + for (int64_t j = 0; j < index.dim(); j++) { + result.emplace_back(nonzero.select(1, j)); + } + } else { + result.emplace_back(std::move(index)); } - // Replace with nonzeros - auto nonzero = index.nonzero(); - for (int64_t j = 0; j < index.dim(); j++) { - 
result.emplace_back(nonzero.select(1, j)); - } - } else { - result.emplace_back(index); } } return result; } -static void checkIndexTensorTypes(TensorList indices) { - for (auto& tensor : indices) { - if (tensor.defined()) { - auto scalarType = tensor.scalar_type(); +static void checkIndexTensorTypes(const torch::List>& indices) { + for (c10::optional tensor : indices) { + if (tensor.has_value() && tensor->defined()) { + auto scalarType = tensor->scalar_type(); if (scalarType != kLong && scalarType != kByte && scalarType != kBool) { TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors"); } @@ -56,6 +62,15 @@ static void checkIndexTensorTypes(TensorList indices) { } } +inline torch::List> toListOfOptionalTensors(ArrayRef list) { + torch::List> result; + result.reserve(list.size()); + for (const Tensor& a : list) { + result.push_back(a); + } + return result; +} + static bool hasContiguousSubspace(TensorList tl) { // true if all the non-null tensors are adjacent auto isDefined = [](const Tensor & tensor){ return tensor.defined(); }; diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index da8d2bd6db47..a37d1046bac2 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -73,7 +74,8 @@ Tensor logdet(const Tensor& self) { // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)]. Tensor logdet_vals = diag_U.abs_().log_().sum(-1); if (self.dim() > 2) { - logdet_vals.index_put_((det_sign < 0).nonzero_numpy(), at::full({}, NAN, self.options())); + auto indices = toListOfOptionalTensors((det_sign < 0).nonzero_numpy()); + logdet_vals.index_put_(std::move(indices), at::full({}, NAN, self.options())); } else if (det_sign.item() < 0) { logdet_vals.fill_(NAN); } diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index bfc5f910e093..4d1601d3e6a0 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -31,11 +31,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { "result type ", common_dtype, "can't be cast to the desired output type ", result.scalar_type()); - auto exponent = (exp.isComplex()) ? 
exp.toComplexDouble() : exp.toDouble(); - - if (exponent == 0.0) { + if (exp.equal(0.0)) { result.resize_as_(base).fill_(1); - } else if (exponent == 1.0) { + } else if (exp.equal(1.0)) { result.resize_as_(base).copy_(base); } else { auto iter = TensorIterator::unary_op(result, base.to(common_dtype)); diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 1d9f9d9d2a12..2d79a4e3713f 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -206,7 +206,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) } } -static AdvancedIndex make_info(Tensor self, TensorList orig) { +static AdvancedIndex make_info(Tensor self, const torch::List>& orig) { checkIndexTensorTypes(orig); // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors auto indices = expandTensors(self, orig); @@ -281,7 +281,7 @@ static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& return config.build(); } -Tensor index(const Tensor & self, TensorList indices) { +Tensor index(const Tensor & self, const torch::List>& indices) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); auto info = make_info(self, indices); @@ -290,7 +290,7 @@ Tensor index(const Tensor & self, TensorList indices) { return iter.output(); } -Tensor quantized_index(const Tensor & self, TensorList indices) { +Tensor quantized_index(const Tensor & self, const torch::List>& indices) { TORCH_INTERNAL_ASSERT( self.qscheme() == c10::kPerTensorAffine || self.qscheme() == c10::kPerTensorSymmetric, @@ -311,12 +311,14 @@ Tensor quantized_index(const Tensor & self, TensorList indices) { res, self.q_scale(), self.q_zero_point(), self.scalar_type()); } -Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) { +Tensor& index_out(Tensor& result, const Tensor & self, const torch::List>& indices) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); at::assert_no_internal_overlap(result); at::assert_no_overlap(result, self); - for (auto& index: indices) { - at::assert_no_overlap(result, index); + for (const c10::optional& index: indices) { + if (index.has_value()) { + at::assert_no_overlap(result, *index); + } } auto info = make_info(self, indices); @@ -325,11 +327,11 @@ Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) { return result; } -Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) { +Tensor index_put(const Tensor & self, const torch::List>& indices, const Tensor & value, bool accumulate) { return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate); } -Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate, const bool unsafe) { +Tensor & _index_put_impl_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate, const bool unsafe) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); if (at::has_internal_overlap(self) == MemOverlap::YES) { TORCH_WARN( @@ -338,8 +340,10 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu "This also applies to advanced indexing e.g. 
tensor[indices] = tensor"); } at::assert_no_overlap(self, value); - for (auto& index: indices) { - at::assert_no_overlap(self, index); + for (const c10::optional& index: indices) { + if (index.has_value()) { + at::assert_no_overlap(self, *index); + } } if (accumulate && self.device().type() == kCUDA) { @@ -356,7 +360,7 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu } -Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate) { +Tensor & index_put_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate) { return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false); } diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h index 560b46162546..0e0958606de1 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.h +++ b/aten/src/ATen/native/TensorAdvancedIndexing.h @@ -15,7 +15,7 @@ enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY}; using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides); using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); -using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe); +using index_put_accum_fn = void(*)(Tensor &, const c10::List> &, const Tensor &, bool unsafe); using masked_fill_fn = void(*)(TensorIterator &, Scalar scalar); using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); @@ -42,6 +42,6 @@ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub); DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub); DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub); -TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices); +TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List>& indices); }} // namespace at::native diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index e6dd1bc4afde..0f6da7e4292a 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -326,8 +326,12 @@ Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_imp Tensor reciprocal(const Tensor& self) { return unary_op_impl_float(self, reciprocal_stub); } Tensor& reciprocal_(Tensor& self) { return unary_op_impl_(self, at::reciprocal_out); } -Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); } -Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); } +Tensor& rsqrt_out(Tensor& result, const Tensor& self) { + return unary_op_impl_float_out(result, self, rsqrt_stub); +} +Tensor rsqrt(const Tensor& self) { + return unary_op_impl_float(self, rsqrt_stub); +} Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); } Tensor& sign_out(Tensor& result, const Tensor& self) { diff --git a/aten/src/ATen/native/cpu/PowKernel.cpp b/aten/src/ATen/native/cpu/PowKernel.cpp index b7ec099a80da..6f0d153e978a 100644 --- a/aten/src/ATen/native/cpu/PowKernel.cpp +++ b/aten/src/ATen/native/cpu/PowKernel.cpp @@ -63,7 +63,7 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) { ); } else if (exp == -0.5) { cpu_kernel_vec(iter, - [](scalar_t base) -> scalar_t { + [](scalar_t base) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return 1.0 / std::sqrt(base); }, [](Vec base) -> Vec { return 
base.rsqrt(); } diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 5f96e01ab319..32033abcd4e2 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -225,7 +225,7 @@ static void norm_kernel_tensor_iterator_impl( binary_kernel_reduce( iter, AbsMaxOps(), - std::numeric_limits::min() + acc_t(0) ); }); } else if (val == -INFINITY) { diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 049b3eff6b5b..32ebaf7752f7 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -587,10 +587,10 @@ static void random_full_64_bits_range_kernel(TensorIterator& iter, c10::optional } static void rsqrt_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "rsqrt_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "rsqrt_cpu", [&] { cpu_kernel_vec( iter, - [=](scalar_t a) -> scalar_t { + [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { return (static_cast(1)) / std::sqrt(a); }, [=](Vec256 a) { return a.rsqrt(); }); diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index cb4aa644fee2..d88f202487af 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -190,7 +190,7 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self Tensor _mask = (mask.dim() == 0) ? mask.unsqueeze(0) : mask; Tensor _self = (self.dim() == 0) ? self.unsqueeze(0) : self; std::tie(_mask, _self) = expand_outplace(_mask, _self); - at::native::index_out(result, _self, _mask); + at::native::index_out(result, _self, c10::List>({_mask})); return result; } diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index d630d727019f..e372f8bdb697 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -160,7 +160,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) { } -static std::tuple> makeLinearIndex(Tensor self, TensorList orig, bool check_range) { +static std::tuple> makeLinearIndex(Tensor self, const c10::List>& orig, bool check_range) { checkIndexTensorTypes(orig); // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors auto indices = expandTensors(self, orig); @@ -184,7 +184,7 @@ static std::tuple>& indices, const Tensor & value, bool unsafe) { if (indices.size() > (size_t)self.dim()) { TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); } diff --git a/aten/src/ATen/native/cuda/ReduceNormKernel.cu b/aten/src/ATen/native/cuda/ReduceNormKernel.cu index 3953f16b69c9..3a24f00f6ebf 100644 --- a/aten/src/ATen/native/cuda/ReduceNormKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceNormKernel.cu @@ -28,7 +28,7 @@ void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) { } else if (p == static_cast(2)) { gpu_reduce_kernel(iter, NormTwoOps(), 0); } else if (p == static_cast(INFINITY)) { - gpu_reduce_kernel(iter, AbsMaxOps(), std::numeric_limits::min()); + gpu_reduce_kernel(iter, AbsMaxOps(), 0); } else if (p == static_cast(-INFINITY)) { gpu_reduce_kernel(iter, AbsMinOps(), std::numeric_limits::max()); } else { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a5b945399da8..6b0aaa8f4d9b 
100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2226,6 +2226,7 @@ use_c10_dispatcher: full - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + use_c10_dispatcher: full variants: function, method dispatch: CPU, CUDA: index @@ -2254,6 +2255,7 @@ variants: function, method - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) + use_c10_dispatcher: full variants: function, method dispatch: DefaultBackend: index_put_ @@ -2264,9 +2266,11 @@ # - Tensor & Tensor::index_put_(std::initializer_list indices, Scalar v) - func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + use_c10_dispatcher: full variants: function, method - func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) + use_c10_dispatcher: full variants: function dispatch: CPU, CUDA: _index_put_impl_ diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index b7d893ad55fc..05762bfb036f 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -746,7 +746,7 @@ at::Tensor PackedConvWeightsQnnp::apply_impl( run_status == pytorch_qnnp_status_success, "failed to run quantized::conv2d (qnnpack) operator"); - return output.contiguous(act.suggest_memory_format()); + return output; } template at::Tensor PackedConvWeightsQnnp<2>::apply( diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index d621efafee41..fb7e16539c15 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -14,7 +15,6 @@ namespace at { namespace native { using namespace at::sparse; - /****************************************************************************** * access methods ******************************************************************************/ @@ -328,7 +328,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){ Tensor values; if (self.dim() > 0) { - std::vector ix = indices.chunk(indices.size(0), 0); + auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0)); values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve); } else { AT_ASSERT(nz.sizes().equals({0, 1})); diff --git a/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl b/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl index 58394dca19da..2c02e034603e 100644 --- a/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl +++ b/aten/src/ATen/native/vulkan/glsl/KO4C4HW_to_image.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) readonly buffer kernel { vec4 data[]; diff --git a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl index d5b9af843dbe..75243a69bca3 100644 --- a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/add.glsl 
b/aten/src/ATen/native/vulkan/glsl/add.glsl index 8dcff0476edf..361927373a49 100644 --- a/aten/src/ATen/native/vulkan/glsl/add.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/add_.glsl b/aten/src/ATen/native/vulkan/glsl/add_.glsl index ed82d0cbe87b..d6360a376c58 100644 --- a/aten/src/ATen/native/vulkan/glsl/add_.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add_.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl index 8882ba0d8ff2..735086a8150a 100644 --- a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl index bffd680669fb..a418a28bb5c3 100644 --- a/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/addmm.glsl b/aten/src/ATen/native/vulkan/glsl/addmm.glsl index 61f76fa8cf5d..a8f09252a167 100644 --- a/aten/src/ATen/native/vulkan/glsl/addmm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/addmm.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl index df2bbcf18014..5de8cf13225f 100644 --- a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/clamp.glsl b/aten/src/ATen/native/vulkan/glsl/clamp.glsl index c394dfd26627..52c2d2d96c26 100644 --- a/aten/src/ATen/native/vulkan/glsl/clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/clamp.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/clamp_.glsl b/aten/src/ATen/native/vulkan/glsl/clamp_.glsl index b16258685114..3f138bb93ec6 100644 --- a/aten/src/ATen/native/vulkan/glsl/clamp_.glsl +++ b/aten/src/ATen/native/vulkan/glsl/clamp_.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl index 9646eb8c9f19..bb2508aefe65 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git 
a/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl index fe50262f7d46..0f49515718b2 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl index 37a5898b9f10..5155c07669c1 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_dw_clamp.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl index b73c58e0f54d..89411284fed4 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl index 5cef89c2727f..8baae9b5fcd5 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba32f, binding = 0) writeonly PRECISION uniform image3D uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel; diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl index 48d9f785008b..1355b2c09b05 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl index d19c370ec9bd..01d653bf06de 100644 --- a/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl index 948b797a5207..88373605d010 100644 --- a/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl @@ -1,7 +1,6 @@ #version 450 core #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D 
uOutput; layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; layout(set = 0, binding = 2) uniform constBlock { diff --git a/aten/src/ATen/native/vulkan/glsl/mean.glsl b/aten/src/ATen/native/vulkan/glsl/mean.glsl index 130d716ca9e6..551fd747f103 100644 --- a/aten/src/ATen/native/vulkan/glsl/mean.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mean.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/mean2d.glsl b/aten/src/ATen/native/vulkan/glsl/mean2d.glsl index 266226aa708b..b8d0add329f2 100644 --- a/aten/src/ATen/native/vulkan/glsl/mean2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mean2d.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/mm.glsl b/aten/src/ATen/native/vulkan/glsl/mm.glsl index 00ab5f31e6db..157acfe9c074 100644 --- a/aten/src/ATen/native/vulkan/glsl/mm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mm.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl index d3a98ba30bea..c0ae48fe3883 100644 --- a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl index b49252e128cc..f959052879ad 100644 --- a/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl b/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl index fb87b5a36918..adbafcbd0438 100644 --- a/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl +++ b/aten/src/ATen/native/vulkan/glsl/nchw_to_image.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/permute.glsl b/aten/src/ATen/native/vulkan/glsl/permute.glsl index af8e33588f78..3d1191ff6eea 100644 --- a/aten/src/ATen/native/vulkan/glsl/permute.glsl +++ b/aten/src/ATen/native/vulkan/glsl/permute.glsl @@ -1,6 +1,5 @@ #version 450 core layout(std430) buffer; -layout(std430) uniform; layout(set = 0, binding = 0) writeonly buffer outputBuffer { float data[]; } diff --git a/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl index efb1c5c7fc9a..b4db9b87dacb 100644 --- a/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl @@ -2,7 +2,6 @@ #define PRECISION $precision layout(std430) buffer; -layout(std430) uniform; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index d42c8c23fe9c..1c0a04a318d0 100644 --- 
a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -28,6 +28,7 @@ class Tensor; } namespace c10{ struct TensorOptions; +template class List; } namespace at { struct Generator; diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index 68c0b4f3f71a..3b7bfb47fe62 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -138,3 +138,23 @@ TEST(TestScalar, TestConj) { ASSERT_EQ(float_scalar.conj().toDouble(), 3.0); ASSERT_EQ(complex_scalar.conj().toComplexDouble(), c10::complex(2.3, -3.5)); } + +TEST(TestScalar, TestEqual) { + ASSERT_FALSE(Scalar(1.0).equal(false)); + ASSERT_FALSE(Scalar(1.0).equal(true)); + ASSERT_FALSE(Scalar(true).equal(1.0)); + ASSERT_TRUE(Scalar(true).equal(true)); + + ASSERT_TRUE(Scalar(c10::complex{2.0, 5.0}).equal(c10::complex{2.0, 5.0})); + ASSERT_TRUE(Scalar(c10::complex{2.0, 0}).equal(2.0)); + ASSERT_TRUE(Scalar(c10::complex{2.0, 0}).equal(2)); + + ASSERT_TRUE(Scalar(2.0).equal(c10::complex{2.0, 0.0})); + ASSERT_FALSE(Scalar(2.0).equal(c10::complex{2.0, 4.0})); + ASSERT_FALSE(Scalar(2.0).equal(3.0)); + ASSERT_TRUE(Scalar(2.0).equal(2)); + + ASSERT_TRUE(Scalar(2).equal(c10::complex{2.0, 0})); + ASSERT_TRUE(Scalar(2).equal(2)); + ASSERT_TRUE(Scalar(2).equal(2.0)); +} diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 48bceb440954..b175e5bdd6ce 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -23,7 +23,7 @@ configure_file( ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h) # Note: if you want to add ANY dependency to the c10 library, make sure you -# check with the core PyTorch developers as the dependendency will be +# check with the core PyTorch developers as the dependency will be # transitively passed on to all libraries dependent on PyTorch. file(GLOB C10_SRCS *.cpp diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 486272ece92e..58d456b950ed 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -124,7 +124,7 @@ class DispatchKeySet final { public: // STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the // set. The iterator is only invalidated by the destruction of the underlying - // DispatchKeySet as the iterator stores a pointer to the raw represenation of + // DispatchKeySet as the iterator stores a pointer to the raw representation of // the DispatchKeySet. class iterator { public: @@ -235,7 +235,7 @@ C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); // This API exists because we have a use case for checking -// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefind) +// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) // in OperatorEntry.cpp but we disallow it in has() API. C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index e25814cd0717..6528f6c8f110 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -98,7 +98,7 @@ inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { // 1. 
Please do not combine these helper functions, each helper function handles // exactly one case of sizes + memory_format, by doing this, the strides indices // will be a constant array and we can access it using constant index number, -// the complier will fully unroll the loop on strides indices to gain a better +// the compiler will fully unroll the loop on strides indices to gain a better // performance. // 2. No error check in helper function, caller ensures the correctness of the input // 3. All helper functions have similar comments, only 1st helper function is commented here. @@ -205,7 +205,7 @@ inline bool is_channels_last_strides_3d_s5(const IntArrayRef sizes, const IntArr // a. we identify corner cases where the implementation compromises on. // // By the time accumulated permutation is enabled to replace implicit -// memory_foramt through strides, we should be updating our tests and fix the +// memory_format through strides, we should be updating our tests and fix the // issues in our tests. // // We use Channels Last 2d as an example above. diff --git a/c10/core/Scalar.cpp b/c10/core/Scalar.cpp index 35aa5d60f001..203b544924ec 100644 --- a/c10/core/Scalar.cpp +++ b/c10/core/Scalar.cpp @@ -3,7 +3,7 @@ namespace c10 { Scalar Scalar::operator-() const { - TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not suppported."); + TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not supported."); if (isFloatingPoint()) { return Scalar(-v.d); } else if (isComplex()) { @@ -21,4 +21,14 @@ Scalar Scalar::conj() const { } } +Scalar Scalar::log() const { + if (isComplex()) { + return std::log(v.z); + } else if (isFloatingPoint()) { + return std::log(v.d); + } else { + return std::log(v.i); + } +} + } // namespace c10 diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 6151f6d2b150..368228e8202e 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -88,6 +88,45 @@ class C10_API Scalar { Scalar operator-() const; Scalar conj() const; + Scalar log() const; + + template::value, int>::type = 0> + bool equal(T num) const { + if (isComplex()) { + auto val = v.z; + return (val.real() == num) && (val.imag() == T()); + } else if (isFloatingPoint()) { + return v.d == num; + } else if (isIntegral(/*includeBool=*/false)) { + return v.i == num; + } else { + // boolean scalar does not equal to a non boolean value + return false; + } + } + + template::value, int>::type = 0> + bool equal(T num) const { + if (isComplex()) { + return v.z == num; + } else if (isFloatingPoint()) { + return (v.d == num.real()) && (num.imag() == T()); + } else if (isIntegral(/*includeBool=*/false)) { + return (v.i == num.real()) && (num.imag() == T()); + } else { + // boolean scalar does not equal to a non boolean value + return false; + } + } + + bool equal(bool num) const { + if (isBoolean()) { + return static_cast(v.i) == num; + } else { + return false; + } + } + ScalarType type() const { if (isComplex()) { return ScalarType::ComplexDouble; diff --git a/c10/core/Stream.cpp b/c10/core/Stream.cpp index 9a5c838c73fe..1a56c9d68567 100644 --- a/c10/core/Stream.cpp +++ b/c10/core/Stream.cpp @@ -2,7 +2,7 @@ namespace c10 { -// Not very parseable, but I don't know a good compact syntax for streams. +// Not very parsable, but I don't know a good compact syntax for streams. // Feel free to change this into something more compact if needed. 
std::ostream& operator<<(std::ostream& stream, const Stream& s) { stream << "stream " << s.id() << " on device " << s.device(); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 3326404e1d07..e7f9c1260263 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -19,7 +19,7 @@ #include // A global boolean variable to control whether we free memory when a Tensor -// is shrinked to a smaller size. As a result, a Tensor is always going to +// is shrunk to a smaller size. As a result, a Tensor is always going to // keep the memory allocated for its maximum capacity reshaped to so far. // // This parameter is respected "upper-case" methods which call Resize() @@ -625,7 +625,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * The API is as follows: * - "new_grad" is a Tensor containing the new value of the gradient that should * be set - * - "self" should reprensent the Tensor whose forward grad is accessed. It is + * - "self" should represent the Tensor whose forward grad is accessed. It is * required when dealing with view. * - "level" allows to specify the level of forward AD nesting for which the * gradient should be set. Note that since levels are not fully supported @@ -1381,7 +1381,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // error in attempt to invoke TypeMeta::ctor() static_assert( std::is_default_constructible::value, - "Tensor can't hold non-default-constructible types"); + "Tensor can't hold non-default-constructable types"); return static_cast(raw_mutable_data(caffe2::TypeMeta::Make())); } diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 2ef02b57d3be..258f8953f4de 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -126,7 +126,7 @@ struct C10_API DeviceGuardImplInterface { /** * Increments the event's version and enqueues a job with this version * in the stream's work queue. When the stream process that job - * it nofifies all streams waiting on / blocked by that version of the + * it notifies all streams waiting on / blocked by that version of the * event to continue and marks that version as recorded. * */ virtual void record( diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index c8fa53df6f02..256fc54b08a1 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -13,7 +13,7 @@ configure_file( ${CMAKE_BINARY_DIR}/c10/cuda/impl/cuda_cmake_macros.h) # Note: if you want to add ANY dependency to the c10 library, make sure you -# check with the core PyTorch developers as the dependendency will be +# check with the core PyTorch developers as the dependency will be # transitively passed on to all libraries dependent on PyTorch. 
# Note: if you add a new source file/header, you will need to update diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 0b5d2992538c..493296248e5b 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -62,7 +62,7 @@ constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 M constexpr size_t kSmallBuffer = 2097152; // "small" allocations are packed in 2 MiB blocks constexpr size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer -constexpr size_t kRoundLarge = 2097152; // round up large allocs to 2 MiB +constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB typedef std::bitset(StatType::NUM_TYPES)> StatTypes; diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 457331f4a00d..d1e290c3f02c 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -60,7 +60,7 @@ static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS]; // in the pool to be returned when a stream is requested (round-robin fashion // , see the note in CUDAStream.h). // -// unique_ptr is used instead of vector because T might be non-moveable +// unique_ptr is used instead of vector because T might be non-movable // and non-copyable. static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS]; static std::atomic low_priority_counters[C10_COMPILE_TIME_MAX_GPUS]; diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index 41802b3bc9ef..05eddf5ce122 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -152,7 +152,7 @@ class C10_CUDA_API CUDAStream { static std::tuple priority_range() { // Note: this returns the range of priority **supported by PyTorch**, not // the range of priority **supported by CUDA**. The former is a subset of - // the latter. Curently PyTorch only supports 0 and -1, which are "low" and + // the latter. Currently PyTorch only supports 0 and -1, which are "low" and // "high" priority. int least_priority, greatest_priority; C10_CUDA_CHECK( diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 46ff50621417..5499a7d8b81c 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -316,7 +316,7 @@ __host__ __device__ #define C10_MOBILE 1 #endif // ANDROID / IOS -// Portably determine if a type T is trivially copyable or not. +// Portable determination of whether type T is trivially copyable. // Warning: __has_trivial_copy for GCC may not always detect the non-POD // correctly. For example, T = std::unique_ptr may evaluate to true and be // treated as POD. This can cause unexpected behavior. diff --git a/c10/mobile/CPUCachingAllocator.cpp b/c10/mobile/CPUCachingAllocator.cpp index bde4067d45dc..0114856ca89b 100644 --- a/c10/mobile/CPUCachingAllocator.cpp +++ b/c10/mobile/CPUCachingAllocator.cpp @@ -61,7 +61,7 @@ void CPUCachingAllocator::record_free(void* ptr) { // is being freed outside the scope of this allocator. // At the moment only way to capture this is to have the allocator, // that uses this CachingAllocator as the backing allocator, - // call this function explicity upon freeing memory while + // call this function explicitly upon freeing memory while // outside the scope of caching allocator. // If the memory is freed in some other way, then we will likely // have undefined behavior or page fault. 
But this can be diff --git a/c10/mobile/CPUCachingAllocator.h b/c10/mobile/CPUCachingAllocator.h index 2f11e6ea8669..c80fee0682eb 100644 --- a/c10/mobile/CPUCachingAllocator.h +++ b/c10/mobile/CPUCachingAllocator.h @@ -26,7 +26,7 @@ * What are the cons? * There are some cons that were observed where use of caching allocator led to * worse performance on some platforms. Reason being that the caching mechanism - * used by this allocator left us worse off compared to the corresonding platform's + * used by this allocator left us worse off compared to the corresponding platform's * tuned memory allocator. In that case it seemed better to not use this allocator. * Note there are some ideas to fix this in the works. * @@ -63,7 +63,7 @@ class C10_API CPUCachingAllocator { // returned the memory to OS via free_cached. // 1.1. Therefore even when the said memory is "freed" via this // allocator (and thus cached), it will continue to stay - // in allocaiton_map_. Furthermore it will also exist in + // in allocation_map_. Furthermore it will also exist in // available_map_. Thus an allocated memory pointer can be in both // allocation_map_ and available_map_ simultaneously. // 2. Memory pointer maybe removed from allocation_map_, when it diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp index 5f2b28b4b2d0..0118d0a29587 100644 --- a/c10/mobile/CPUProfilingAllocator.cpp +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -133,7 +133,7 @@ std::vector formulate_greedy_allocation_plan( ska::flat_hash_map::iterator> free_end_offset_to_size_iter; // Upon free end_ptr = offset + size // If end_ptr exists merge freed allocation - // Also find coresponding offset in size_to_offet + // Also find corresponding offset in size_to_offset // Remove that entry and update with new size and offset // If end_ptr does not exist then just insert offset,size // in map and correspondingly size, offset in the other map. @@ -176,7 +176,7 @@ std::vector formulate_greedy_allocation_plan( } allocation_offsets[mem_event.allocation_id] = alloc_offset; } else { - // 1. Check if freed block is adjancent to an existing free block + // 1. Check if freed block is adjacent to an existing free block // at its end boundary. This is done by checking // free_end_offset_to_size_iter. // If we find such a block, remove it and adjust size of @@ -186,7 +186,7 @@ std::vector formulate_greedy_allocation_plan( // free_start_offset_to_size_iter. // If we find such a block, remove it and adjust size of // the block being freed. - // 3. Inser the freed block in map. + // 3. Insert the freed block in map. auto freed_offset = allocation_offsets[mem_event.allocation_id]; auto freed_size = mem_event.size; auto end_offset = freed_offset + freed_size; @@ -223,7 +223,7 @@ std::vector formulate_greedy_allocation_plan( } } TORCH_CHECK(validate_allocation_plan(mem_events, allocation_offsets), - "ProfilingAllocator: Allocation plan invaild."); + "ProfilingAllocator: Allocation plan invalid."); return allocation_offsets; } @@ -394,7 +394,7 @@ CPUProfilingAllocator::~CPUProfilingAllocator() { WithProfileAllocationsGuard::WithProfileAllocationsGuard( AllocationPlan* plan) { - // Nesting of allocation profiling does not seem meanigful. + // Nesting of allocation profiling does not seem meaningful. 
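The free-block coalescing that formulate_greedy_allocation_plan documents above is the classic free-list merge: on free, absorb the neighbour ending at the block's start, absorb the neighbour starting at its end, then insert the merged range. A compact sketch of the same idea using a single ordered map (the allocator itself keeps two hash maps keyed by start and end offset) is:

#include <cassert>
#include <cstddef>
#include <iterator>
#include <map>

// Free list keyed by start offset -> size.
void free_block(std::map<size_t, size_t>& free_blocks, size_t offset, size_t size) {
  // Merge with the block starting exactly at our end boundary.
  auto right = free_blocks.find(offset + size);
  if (right != free_blocks.end()) {
    size += right->second;
    free_blocks.erase(right);
  }
  // Merge with the block whose end boundary is exactly our start.
  auto it = free_blocks.lower_bound(offset);
  if (it != free_blocks.begin()) {
    auto left = std::prev(it);
    if (left->first + left->second == offset) {
      offset = left->first;
      size += left->second;
      free_blocks.erase(left);
    }
  }
  free_blocks[offset] = size;
}

int main() {
  std::map<size_t, size_t> free_blocks;
  free_block(free_blocks, 0, 64);    // [0, 64)
  free_block(free_blocks, 128, 64);  // [0, 64) and [128, 192)
  free_block(free_blocks, 64, 64);   // fills the hole -> one block [0, 192)
  assert(free_blocks.size() == 1 && free_blocks.at(0) == 192);
  return 0;
}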
TORCH_CHECK(allocation_planner == nullptr, "Nesting profiling allocations is not supported."); planner_ = std::make_unique(plan); @@ -409,7 +409,7 @@ WithProfileAllocationsGuard::~WithProfileAllocationsGuard() { WithValidateAllocationPlanGuard::WithValidateAllocationPlanGuard( AllocationPlan* plan, bool* success) { - // Nesting of allocation profiling does not seem meanigful. + // Nesting of allocation profiling does not seem meaningful. TORCH_CHECK(allocation_planner == nullptr, "Nesting profiling allocations is not supported."); planner_ = std::make_unique(plan, true); diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index d08f512053ab..af00bab99c5b 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -87,7 +87,7 @@ namespace { } TEST(BFloat16Math, Addition) { - // This test verifies that if only first 7 bits of float's mantisa are + // This test verifies that if only first 7 bits of float's mantissa are // changed after addition, we should have no loss in precision. // input bits @@ -108,8 +108,8 @@ namespace { EXPECT_EQ(res, expected); } - TEST(BFloat16Math, Substraction) { - // This test verifies that if only first 7 bits of float's mantisa are + TEST(BFloat16Math, Subtraction) { + // This test verifies that if only first 7 bits of float's mantissa are // changed after subtraction, we should have no loss in precision. // input bits diff --git a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp index 2ea283d1a4f0..9df5b004a094 100644 --- a/c10/test/util/intrusive_ptr_test.cpp +++ b/c10/test/util/intrusive_ptr_test.cpp @@ -694,21 +694,21 @@ TEST(IntrusivePtrTest, Equality_Nullptr) { EXPECT_FALSE(var1 != var2); } -TEST(IntrusivePtrTest, Nonequality) { +TEST(IntrusivePtrTest, Inequality) { intrusive_ptr var1 = make_intrusive(); intrusive_ptr var2 = make_intrusive(); EXPECT_TRUE(var1 != var2); EXPECT_FALSE(var1 == var2); } -TEST(IntrusivePtrTest, Nonequality_NullptrLeft) { +TEST(IntrusivePtrTest, Inequality_NullptrLeft) { intrusive_ptr var1; intrusive_ptr var2 = make_intrusive(); EXPECT_TRUE(var1 != var2); EXPECT_FALSE(var1 == var2); } -TEST(IntrusivePtrTest, Nonequality_NullptrRight) { +TEST(IntrusivePtrTest, Inequality_NullptrRight) { intrusive_ptr var1 = make_intrusive(); intrusive_ptr var2; EXPECT_TRUE(var1 != var2); @@ -2487,28 +2487,28 @@ TEST(WeakIntrusivePtrTest, Equality_Invalid) { EXPECT_FALSE(var1 != var2); } -TEST(WeakIntrusivePtrTest, Nonequality) { +TEST(WeakIntrusivePtrTest, Inequality) { IntrusiveAndWeak var1 = make_intrusive(); IntrusiveAndWeak var2 = make_intrusive(); EXPECT_TRUE(var1.weak != var2.weak); EXPECT_FALSE(var1.weak == var2.weak); } -TEST(WeakIntrusivePtrTest, Nonequality_InvalidLeft) { +TEST(WeakIntrusivePtrTest, Inequality_InvalidLeft) { weak_intrusive_ptr var1 = make_invalid_weak(); IntrusiveAndWeak var2 = make_intrusive(); EXPECT_TRUE(var1 != var2.weak); EXPECT_FALSE(var1 == var2.weak); } -TEST(WeakIntrusivePtrTest, Nonequality_InvalidRight) { +TEST(WeakIntrusivePtrTest, Inequality_InvalidRight) { IntrusiveAndWeak var1 = make_intrusive(); weak_intrusive_ptr var2 = make_invalid_weak(); EXPECT_TRUE(var1.weak != var2); EXPECT_FALSE(var1.weak == var2); } -TEST(WeakIntrusivePtrTest, Nonequality_WeakOnly) { +TEST(WeakIntrusivePtrTest, Inequality_WeakOnly) { weak_intrusive_ptr var1 = make_weak_only(); weak_intrusive_ptr var2 = make_weak_only(); EXPECT_TRUE(var1 != var2); diff --git a/c10/util/Bitset.h b/c10/util/Bitset.h index e849563e60fe..964146be05e7 100644 --- 
a/c10/util/Bitset.h +++ b/c10/util/Bitset.h @@ -64,7 +64,7 @@ struct bitset final { bitset cur = *this; size_t index = cur.find_first_set(); while (0 != index) { - // -1 because find_first_set() is not one-indiced. + // -1 because find_first_set() is not one-indexed. index -= 1; func(index); cur.unset(index); @@ -73,7 +73,7 @@ struct bitset final { } private: - // Return the index of the first set bit. The returned index is one-indiced + // Return the index of the first set bit. The returned index is one-indexed // (i.e. if the very first bit is set, this function returns '1'), and a return // of '0' means that there was no bit set. size_t find_first_set() const { diff --git a/c10/util/Flags.h b/c10/util/Flags.h index 6bfe62507fcd..b4352510c997 100644 --- a/c10/util/Flags.h +++ b/c10/util/Flags.h @@ -4,7 +4,7 @@ /* Commandline flags support for C10. * * This is a portable commandline flags tool for c10, so we can optionally - * choose to use gflags or a lightweighted custom implementation if gflags is + * choose to use gflags or a lightweight custom implementation if gflags is * not possible on a certain platform. If you have gflags installed, set the * macro C10_USE_GFLAGS will seamlessly route everything to gflags. * diff --git a/c10/util/Logging.h b/c10/util/Logging.h index acab3cfecd23..6fa7e93f26d8 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -284,7 +284,7 @@ BINARY_COMP_HELPER(LessEquals, <=) * Very lightweight logging for the first time API usage. It's beneficial for * tracking of individual functionality usage in larger applications. * - * In order to ensure light-weightness of logging, we utilize static variable + * In order to ensure light-weightedness of logging, we utilize static variable * trick - LogAPIUsage will be invoked only once and further invocations will * just do an atomic check. * diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index 076a1d401065..9b32d8edfe7f 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -832,7 +832,7 @@ SmallVectorImpl& SmallVectorImpl::operator=( // If we have to grow to have enough elements, destroy the current elements. // This allows us to avoid copying them during the grow. - // FIXME: don't do this if they're efficiently moveable. + // FIXME: don't do this if they're efficiently movable. if (this->capacity() < RHSSize) { // Destroy current elements. this->destroy_range(this->begin(), this->end()); diff --git a/c10/util/TypeCast.h b/c10/util/TypeCast.h index df15509d7e0f..85513ecc5e2f 100644 --- a/c10/util/TypeCast.h +++ b/c10/util/TypeCast.h @@ -44,7 +44,7 @@ struct static_cast_with_inter_type { // Note: Converting from negative float values to unsigned integer types is // undefined behavior in C++, and current CPU and GPU compilers exhibit // divergent behavior. Casting from negative float values to signed -// integer types and then to unsigned integer types is not undefiend, +// integer types and then to unsigned integer types is not undefined, // however, so this cast improves the consistency of type conversions // to uint8 across compilers. 
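The uint8 note in TypeCast.h above is easy to demonstrate: converting a negative float directly to an unsigned integer type is undefined behavior, whereas converting through the matching signed type first is well defined (the signed-to-unsigned step then wraps modulo 2^8), so compilers agree on the result:

#include <cstdint>
#include <iostream>

int main() {
  float x = -1.0f;
  // static_cast<uint8_t>(x) would be undefined behavior: the truncated
  // value -1 is not representable in uint8_t.
  // float -> int8_t is fine for -1, and int8_t -> uint8_t wraps mod 256.
  uint8_t y = static_cast<uint8_t>(static_cast<int8_t>(x));
  std::cout << static_cast<int>(y) << "\n";  // prints 255 on conforming compilers
  return 0;
}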
// Further note: Type conversions across compilers still have other undefined diff --git a/c10/util/complex.h b/c10/util/complex.h index 2578da2957ab..d4d5525170af 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -61,7 +61,7 @@ namespace c10 { // Since we only support float and double, on will use `complex& operator=(T x)` // - Copy assignment operator and converting assignment operator // - There is no specialization of converting assignment operators, which type is -// convertible is soly depend on whether the scalar type is convertable +// convertible is solely dependent on whether the scalar type is convertible // // In addition to the standard assignment, we also provide assignment operators with std and thrust // diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 761dd27d6d46..637db95991f2 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -700,7 +700,7 @@ class weak_intrusive_ptr final { /** * Takes an owning (but must be weakly referenced) pointer to TTarget* and * creates a weak_intrusive_ptr that takes over ownership. - * Thas means the weakcount is not increased. + * This means that the weakcount is not increased. * This is the counter-part to weak_intrusive_ptr::release() and the pointer * passed in *must* have been created using weak_intrusive_ptr::release(). */ diff --git a/c10/util/typeid.cpp b/c10/util/typeid.cpp index f3fe048b4cca..79c093cbeb31 100644 --- a/c10/util/typeid.cpp +++ b/c10/util/typeid.cpp @@ -60,7 +60,7 @@ CAFFE_KNOWN_TYPE(bool*) CAFFE_KNOWN_TYPE(char*) CAFFE_KNOWN_TYPE(int*) -// For some of the compilers, long is definied separately from int32_t and +// For some of the compilers, long is defined separately from int32_t and // int64_t. As a result we will need to actually define them separately. // It is recommended that one does NOT use long - use int32_t and int64_t // explicitly. Explicit long type annotation may go away in the future. diff --git a/caffe2/contrib/aten/aten_op.cc b/caffe2/contrib/aten/aten_op.cc index 9e7479141ad4..dba68d21c2dd 100644 --- a/caffe2/contrib/aten/aten_op.cc +++ b/caffe2/contrib/aten/aten_op.cc @@ -6,13 +6,17 @@ namespace caffe2 { namespace internal { at::Tensor index_with_uint8_handling( const at::Tensor& self, - at::TensorList indices) { + const torch::List>& indices) { // Support BC only for the simplest case of mask indexing - if (indices.size() == 1 && indices[0].scalar_type() == at::kByte) { - TORCH_WARN( - "Indexing with uint8 mask tensor in ATenOp is now deprecated," - " please use a bool mask instead."); - return at::index(self, {indices[0].to(at::kBool)}); + if (indices.size() == 1) { + c10::optional first = indices[0]; + if (first.has_value() + && first->scalar_type() == at::kByte) { + TORCH_WARN( + "Indexing with uint8 mask tensor in ATenOp is now deprecated," + " please use a bool mask instead."); + return at::index(self, {first->to(at::kBool)}); + } } return at::index(self, indices); } diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index f3a42dbd8f59..cd1ce7651b48 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -21,7 +21,7 @@ using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...) 
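Since the aten_op change above switches the indices argument from TensorList to c10::List<c10::optional<Tensor>>, callers holding a vector of (possibly undefined) index tensors need to repackage it. A rough, untested sketch of such a helper, assuming an ATen build (to_optional_indices is a hypothetical name, not an existing API):

#include <ATen/ATen.h>
#include <vector>

// Repackage a vector of index tensors for the new aten::index-style
// signature; an undefined tensor becomes a c10::nullopt "skip this
// dimension" slot.
c10::List<c10::optional<at::Tensor>> to_optional_indices(
    const std::vector<at::Tensor>& indices) {
  c10::List<c10::optional<at::Tensor>> out;
  out.reserve(indices.size());
  for (const auto& t : indices) {
    if (t.defined()) {
      out.push_back(t);
    } else {
      out.push_back(c10::nullopt);
    }
  }
  return out;
}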
namespace internal { TORCH_API at::Tensor index_with_uint8_handling( const at::Tensor& self, - at::TensorList indices); + const torch::List>& indices); } template @@ -86,6 +86,16 @@ class ATenOp : public Operator { std::vector peekSlice(size_t i, size_t len, size_t N) { std::vector results; + results.reserve(len); + for (size_t ii = i; ii < i + len; ++ii) { + results.push_back(peek(ii, N)); + } + return results; + } + + torch::List> peekSliceOptionals(size_t i, size_t len, size_t N) { + torch::List> results; + results.reserve(len); for (size_t ii = i; ii < i + len; ++ii) { results.push_back(peek(ii, N)); } diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py index 2a822058bfdf..ba29ab933da9 100755 --- a/caffe2/contrib/aten/gen_op.py +++ b/caffe2/contrib/aten/gen_op.py @@ -68,7 +68,7 @@ def value_has_tensors(v): def value_is_tensor_type(v): - return value_has_tensors(v) and v['dynamic_type'] != 'TensorList' + return value_has_tensors(v) and v['dynamic_type'] not in ['TensorList', 'const c10::List> &'] # for each aten type, how do we handle a return value of that type? @@ -208,7 +208,7 @@ def self_as_first_argument(arguments): def get_num_inputs(o): args = 0 for a in o['arguments']: - if a['type'] == 'TensorList': + if a['type'] in ['TensorList', 'const c10::List> &']: return '*' elif value_has_tensors(a): args += 1 @@ -277,10 +277,10 @@ def emit_assignments(o, env): # e.g. "Float" is at::kFloat assert('Type' in o['method_of']) - static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments']) - has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments']) + static_tensor_inputs = sum(arg['type'] not in ['TensorList', 'const c10::List> &'] and value_is_tensor_type(arg) for arg in o['arguments']) + has_tensorlist = any(arg['type'] in ['TensorList', 'const c10::List> &'] for arg in o['arguments']) if has_tensorlist: - tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] == 'TensorList'][0] + tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['TensorList', 'const c10::List> &']][0] real_inputs = 0 for i, arg in enumerate(o['arguments']): @@ -290,10 +290,16 @@ def emit_assignments(o, env): view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs if arg['type'] == 'TensorList': # NOTE: do not advance real_inputs here. After this we will - # switch to indexing the "stack" from the end as if we only had + # switch to indexing the "stack" from the end env['statements'].append( 'auto {} = peekSlice({}, InputSize() - {}, InputSize());' .format(arg['name'], real_inputs, static_tensor_inputs)) + elif arg['type'] == 'const c10::List> &': + # NOTE: do not advance real_inputs here. 
After this we will + # switch to indexing the "stack" from the end + env['statements'].append( + 'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());' + .format(arg['name'], real_inputs, static_tensor_inputs)) elif value_is_tensor_type(arg): # load tensor inputs from Caffe2 env['statements'].append( diff --git a/caffe2/contrib/fakelowp/test/test_fusions.py b/caffe2/contrib/fakelowp/test/test_fusions.py index 335159c8318e..3e22d7c5937b 100644 --- a/caffe2/contrib/fakelowp/test/test_fusions.py +++ b/caffe2/contrib/fakelowp/test/test_fusions.py @@ -27,7 +27,7 @@ class Fusions(serial.SerializedTestCase): rand_seed=st.integers(0, 65534), ) @settings(deadline=datetime.timedelta(seconds=10)) - def Skip_test_tanhquantize(self, scale, zp, size, rand_seed): + def test_tanhquantize(self, scale, zp, size, rand_seed): np.random.seed(rand_seed) workspace.ResetWorkspace() diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index a34a6db70115..87c3151bbb76 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -12,6 +12,7 @@ #include "caffe2/serialize/istream_adapter.h" #include "caffe2/serialize/read_adapter_interface.h" +#include "caffe2/serialize/versions.h" extern "C" { typedef struct mz_zip_archive mz_zip_archive; @@ -90,68 +91,6 @@ typedef struct mz_zip_archive mz_zip_archive; namespace caffe2 { namespace serialize { -constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; -constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L; - -// Versions (i.e. why was the version number bumped?) - -// Note [Dynamic Versions and torch.jit.save vs. torch.save] -// -// Our versioning scheme has a "produced file format version" which -// describes how an archive is to be read. The version written in an archive -// is at least this current produced file format version, but may be greater -// if it includes certain symbols. We refer to these conditional versions -// as "dynamic," since they are identified at runtime. -// -// Dynamic versioning is useful when an operator's semantics are updated. -// When using torch.jit.save we want those semantics to be preserved. If -// we bumped the produced file format version on every change, however, -// then older versions of PyTorch couldn't read even simple archives, like -// a single tensor, from newer versions of PyTorch. Instead, we -// assign dynamic versions to these changes that override the -// produced file format version as needed. That is, when the semantics -// of torch.div changed it was assigned dynamic version 4, and when -// torch.jit.saving modules that use torch.div those archives also have -// (at least) version 4. This prevents earlier versions of PyTorch -// from accidentally performing the wrong kind of division. Modules -// that don't use torch.div or other operators with dynamic versions -// can write the produced file format version, and these programs will -// run as expected on earlier versions of PyTorch. -// -// While torch.jit.save attempts to preserve operator semantics, -// torch.save does not. torch.save is analogous to pickling Python, so -// a function that uses torch.div will have different behavior if torch.saved -// and torch.loaded across PyTorch versions. From a technical perspective, -// torch.save ignores dynamic versioning. - -// 1. Initial version -// 2. Removed op_version_set version numbers -// 3. Added type tags to pickle serialization of container types -// 4. 
(Dynamic) Stopped integer division using torch.div -// (a versioned symbol preserves the historic behavior of versions 1--3) -// 5. (Dynamic) Stops torch.full inferring a floating point dtype -// when given bool or integer fill values. -constexpr uint64_t kProducedFileFormatVersion = 0x3L; - -// the version we write when the archive contains bytecode. -// It must be higher or eq to kProducedFileFormatVersion. -// Because torchscript changes is likely introduce bytecode change. -// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion -// should be increased too. The relationship is: -// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion -// >= kProducedFileFormatVersion -constexpr uint64_t kProducedBytecodeVersion = 0x4L; - -static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, - "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion."); - -// Introduce kMinSupportedBytecodeVersion for limited backward compatibility -// support of bytecode. If -// kMinSupportedBytecodeVersion <= model_version <= kProducedBytecodeVersion (in loader), -// we should support this model_version. For example, we provide a wrapper to -// handle an updated operator. -constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; - class TORCH_API PyTorchStreamReader final { public: explicit PyTorchStreamReader(const std::string& file_name); diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h new file mode 100644 index 000000000000..4da4b2c50305 --- /dev/null +++ b/caffe2/serialize/versions.h @@ -0,0 +1,68 @@ +#pragma once + +namespace caffe2 { +namespace serialize { + +constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; +constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L; + +// Versions (i.e. why was the version number bumped?) + +// Note [Dynamic Versions and torch.jit.save vs. torch.save] +// +// Our versioning scheme has a "produced file format version" which +// describes how an archive is to be read. The version written in an archive +// is at least this current produced file format version, but may be greater +// if it includes certain symbols. We refer to these conditional versions +// as "dynamic," since they are identified at runtime. +// +// Dynamic versioning is useful when an operator's semantics are updated. +// When using torch.jit.save we want those semantics to be preserved. If +// we bumped the produced file format version on every change, however, +// then older versions of PyTorch couldn't read even simple archives, like +// a single tensor, from newer versions of PyTorch. Instead, we +// assign dynamic versions to these changes that override the +// produced file format version as needed. That is, when the semantics +// of torch.div changed it was assigned dynamic version 4, and when +// torch.jit.saving modules that use torch.div those archives also have +// (at least) version 4. This prevents earlier versions of PyTorch +// from accidentally performing the wrong kind of division. Modules +// that don't use torch.div or other operators with dynamic versions +// can write the produced file format version, and these programs will +// run as expected on earlier versions of PyTorch. +// +// While torch.jit.save attempts to preserve operator semantics, +// torch.save does not. torch.save is analogous to pickling Python, so +// a function that uses torch.div will have different behavior if torch.saved +// and torch.loaded across PyTorch versions. 
From a technical perspective, +// torch.save ignores dynamic versioning. + +// 1. Initial version +// 2. Removed op_version_set version numbers +// 3. Added type tags to pickle serialization of container types +// 4. (Dynamic) Stopped integer division using torch.div +// (a versioned symbol preserves the historic behavior of versions 1--3) +// 5. (Dynamic) Stops torch.full inferring a floating point dtype +// when given bool or integer fill values. +constexpr uint64_t kProducedFileFormatVersion = 0x3L; + +// the version we write when the archive contains bytecode. +// It must be higher or eq to kProducedFileFormatVersion. +// Because torchscript changes is likely introduce bytecode change. +// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion +// should be increased too. The relationship is: +// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion +// >= kProducedFileFormatVersion +constexpr uint64_t kProducedBytecodeVersion = 0x4L; + +static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, + "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion."); + +// Introduce kMinSupportedBytecodeVersion for limited backward compatibility +// support of bytecode. If +// kMinSupportedBytecodeVersion <= model_version <= kProducedBytecodeVersion (in loader), +// we should support this model_version. For example, we provide a wrapper to +// handle an updated operator. +constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; +} // namespace serialize +} // namespace caffe2 diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index a389de60416a..1cac90ffab86 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -530,6 +530,71 @@ Best Practices ``fbgemm`` backend. This argument prevents overflow on some int8 instructions by reducing the range of quantized data type by 1 bit. +Common Errors +--------------------------------------- + +Passing a non-quantized Tensor into a quantized kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you see an error similar to:: + + RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend... + +This means that you are trying to pass a non-quantized Tensor to a quantized +kernel. A common workaround is to use ``torch.quantization.QuantStub`` to +quantize the tensor. This needs to be done manually in Eager mode quantization. +An e2e example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + # during the convert step, this will be replaced with a + # `quantize_per_tensor` call + x = self.quant(x) + x = self.conv(x) + return x + +Passing a quantized Tensor into a non-quantized kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you see an error similar to:: + + RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend. + +This means that you are trying to pass a quantized Tensor to a non-quantized +kernel. A common workaround is to use ``torch.quantization.DeQuantStub`` to +dequantize the tensor. This needs to be done manually in Eager mode quantization. 
+An e2e example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.quantization.QuantStub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + # this module will not be quantized (see `qconfig = None` logic below) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + # during the convert step, this will be replaced with a + # `quantize_per_tensor` call + x = self.quant(x) + x = self.conv1(x) + # during the convert step, this will be replaced with a + # `dequantize` call + x = self.dequant(x) + x = self.conv2(x) + return x + + m = M() + m.qconfig = some_qconfig + # turn off quantization for conv2 + m.conv2.qconfig = None + Modules that provide quantization functions and classes ------------------------------------------------------- diff --git a/setup.py b/setup.py index 01f173d6825b..8289b57e93be 100644 --- a/setup.py +++ b/setup.py @@ -892,6 +892,7 @@ def print_box(msg): 'include/torch/csrc/jit/serialization/*.h', 'include/torch/csrc/jit/python/*.h', 'include/torch/csrc/jit/testing/*.h', + 'include/torch/csrc/jit/tensorexpr/*.h', 'include/torch/csrc/onnx/*.h', 'include/torch/csrc/utils/*.h', 'include/pybind11/*.h', diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index e4bb96ece6fb..3f79c771c2be 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -175,7 +175,7 @@ TEST(AutogradAPITests, AnomalyMode) { auto y = x.pow(1.5); auto gr = grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true); - ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan"); + ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan"); auto msgs = warnings.messages(); ASSERT_EQ(msgs.size(), 2); ASSERT_TRUE( diff --git a/test/cpp/api/tensor_indexing.cpp b/test/cpp/api/tensor_indexing.cpp index efb153fbf481..03600c5c882e 100644 --- a/test/cpp/api/tensor_indexing.cpp +++ b/test/cpp/api/tensor_indexing.cpp @@ -83,27 +83,27 @@ TEST(TensorIndexingTest, TestNoIndices) { ASSERT_THROWS_WITH(tensor.index_put_(indices, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); } -TEST(TensorIndexingTest, TestAdvancedIndexingWithArrayRefOfTensor) { +TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) { { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result_with_array_ref = tensor.index(at::ArrayRef({index})); + torch::Tensor result = at::index(tensor, {index}); torch::Tensor result_with_init_list = tensor.index({index}); - ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list)); + ASSERT_TRUE(result.equal(result_with_init_list)); } { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef({index}), torch::ones({20})); + torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20})); torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({20})); - ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list)); + ASSERT_TRUE(result.equal(result_with_init_list)); } { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef({index}), torch::ones({1, 20})); + torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({1, 20})); torch::Tensor result_with_init_list = 
tensor.index_put_({index}, torch::ones({1, 20})); - ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list)); + ASSERT_TRUE(result.equal(result_with_init_list)); } } @@ -173,7 +173,7 @@ TEST(TensorIndexingTest, TestBoolIndices) { TEST(TensorIndexingTest, TestBoolIndicesAccumulate) { auto mask = torch::zeros({10}, torch::kBool); auto y = torch::ones({10, 10}); - y.index_put_({mask}, y.index({mask}), /*accumulate=*/true); + y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true); assert_tensor_equal(y, torch::ones({10, 10})); } diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp index 2e59358b4e00..e102a6ff767c 100644 --- a/test/cpp/jit/test_save_load.cpp +++ b/test/cpp/jit/test_save_load.cpp @@ -120,5 +120,33 @@ TEST(SerializationTest, TypeTags) { } } +TEST(SerializationTest, TestJitStream_CUDA) { + torch::jit::Module model; + std::vector inputs; + // Deserialize the ScriptModule from a file using torch::jit::load(). + // Load the scripted model. This should have been generated by tests_setup.py + // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py + model = torch::jit::load("saved_stream_model.pt"); + + auto output = model.forward(inputs); + auto list_of_elements = output.toTuple()->elements(); + auto is_stream_s = list_of_elements[0].toBool(); + + // a,b: These are the two input tensors + // c: This is output tensor generated by the operation torch.cat(a,b) + auto a = list_of_elements[1].toTensor(); + auto b = list_of_elements[2].toTensor(); + auto c = list_of_elements[3].toTensor(); + // op: this is used to verify if the cat operation produced the same results + // as that on the GPU with torch.cat + auto op = at::cat({a, b}, 0); + + // Check if the stream is set + ASSERT_TRUE(is_stream_s); + // Check if the sizes of the outputs (op and c) is same on the GPU and CPU + ASSERT_EQ(op.sizes(), c.sizes()); + // Check if both the output tensors are equal + ASSERT_TRUE(op.equal(c)); +} } // namespace jit } // namespace torch diff --git a/test/cpp/jit/tests_setup.py b/test/cpp/jit/tests_setup.py index 68871d1c21d2..928a06d9b5a0 100644 --- a/test/cpp/jit/tests_setup.py +++ b/test/cpp/jit/tests_setup.py @@ -63,11 +63,38 @@ def setup(self): torch.save(value, self.path, _use_new_zipfile_serialization=False) +class TorchSaveJitStream_CUDA(FileSetup): + path = 'saved_stream_model.pt' + + def setup(self): + if not torch.cuda.is_available(): + return + + class Model(torch.nn.Module): + def forward(self): + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + a = torch.rand(3, 4, device="cuda") + b = torch.rand(3, 4, device="cuda") + + with torch.jit.cuda.stream(s): + is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id() + c = torch.cat((a, b), 0).to("cuda") + s.synchronize() + return is_stream_s, a, b, c + + model = Model() + + # Script the model and save + script_model = torch.jit.script(model) + torch.jit.save(script_model, self.path) + tests = [ EvalModeForLoadedModule(), SerializationInterop(), TorchSaveError(), + TorchSaveJitStream_CUDA() ] def setup(): diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py new file mode 100644 index 000000000000..f7af8e3a2efc --- /dev/null +++ b/test/jit/test_cuda.py @@ -0,0 +1,476 @@ +import os +import sys +import gc +import unittest + +import torch +from typing import NamedTuple +from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf + +# Make the helper files 
in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) + +# Check if GPU is available +TEST_CUDA = torch.cuda.is_available() +# Check if multiple GPU's are available +TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 + +# If GPU is not available, then do not run the tests +if not TEST_CUDA: + print('CUDA not available, skipping tests', file=sys.stderr) + JitTestCase = object # noqa: F811 + +TEST_LARGE_TENSOR = TEST_CUDA + +# If GPU is available, then initialize the cuda context and check +# if there is memory available to allocate for LARGE Tensors. +if TEST_CUDA: + torch.ones(1).cuda() # initialize cuda context + TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9 + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead." + ) + +class TestCUDA(JitTestCase): + """ + A suite of tests for the CUDA API in TorchScript. + """ + def setUp(self): + super(TestCUDA, self).setUp() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + super(TestCUDA, self).tearDown() + + @skipIfRocm + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + def test_current_stream(self): + # Test current stream on the device and check if the stream device index + # matches with the device ID + @torch.jit.script + def fn(): + device_index = torch.cuda._current_device() + s0 = torch.cuda.current_stream(device_index) + s1 = torch.cuda.current_stream(1) + s2 = torch.cuda.current_stream(0) + + return s0.device_index(), s1.device_index(), s2.device_index() + + d0, d1, d2 = fn() + + # By default, the current device ID is 0. + self.assertEqual(0, d0) + self.assertEqual(1, d1) + self.assertEqual(0, d2) + self.assertEqual(d0, d2) + + @skipIfRocm + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") + @skipCUDANonDefaultStreamIf(True) + def test_streams_and_events(self): + # This test checks for the default stream ID is set to 0 on the device + @torch.jit.script + def test_default_streams(): + s0 = torch.cuda.default_stream(0) + s1 = torch.cuda.default_stream(1) + + d = torch.device('cuda:1') + + # Check the current stream id and default id are same + # on the current device. 
The current device id by default is 0 + s2 = torch.cuda.current_stream(0) + check_s2 = s2.id() == s0.id() + check_d0 = torch.cuda._current_device() == s2.device_index() + + # Set the current device to d1 and check if the stream + # has been set to the default stream on d1 + with torch.jit.cuda.device(d): + s3 = torch.cuda.current_stream(1) + check_s3 = s3.id() == s1.id() + check_d1 = torch.cuda._current_device() == s3.device_index() + + # Check if the current device was reset to 0 + is_device_d0 = torch.cuda._current_device() == s2.device_index() + + return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0 + + d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams() + + self.assertEqual(d0, 0) + self.assertEqual(d1, 1) + self.assertTrue(check_s2) + self.assertTrue(check_s3) + self.assertTrue(check_d0) + self.assertTrue(check_d1) + self.assertTrue(is_device_d0) + + # This test checks if the Stream Context manager is a no op + # when the stream is none for `with torch.jit.cuda.stream` + @torch.jit.script + def test_set_none_stream(): + device_index = torch.cuda._current_device() + current_stream = torch.cuda.current_stream(device_index) + default_stream = torch.cuda.default_stream(device_index) + + # When stream is none, check if this operation is a no-op + with torch.jit.cuda.stream(None): + cur_device_index = torch.cuda._current_device() + is_device_index_same = cur_device_index == device_index + is_current_stream_same = torch.cuda.current_stream(cur_device_index).id() == current_stream.id() + is_default_stream_same = torch.cuda.default_stream(device_index).id() == default_stream.id() + + # Check if the device index, current stream and default streams have not changed + are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same + return are_streams_same + self.assertTrue(test_set_none_stream()) + + # This test checks if the Device Context manager is a no op + # when the device is none for `with torch.jit.cuda.device` + @torch.jit.script + def test_set_device_none(): + device_index = torch.cuda._current_device() + # When device is none, check if this operation is a no-op + with torch.jit.cuda.device(None): + # Check if the current device is the same + is_device_same = torch.cuda._current_device() == device_index + return is_device_same + self.assertTrue(test_set_device_none()) + + # Check if a CUDA JIT stream is created + # on the _current_device + @torch.jit.script + def test_simple_stream(): + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + return device_index == s.device_index() + + self.assertTrue(test_simple_stream(), "Could not create Stream!") + + # Class used to store results for the test: test_get_stream. + class Result(NamedTuple): + t1 : torch.Tensor + t2 : torch.Tensor + is_current_and_default_stream_same : bool + is_default_and_user_stream_not_same : bool + is_stream_set : bool + is_stream_reset : bool + default_stream_query : bool + default_stream_id : int + user_stream_id : int + + # The test aims at checking different stream proporties. 
+ @torch.jit.script + def test_get_stream(): + device_index = torch.cuda._current_device() + current_stream = torch.cuda.current_stream(device_index) + default_stream = torch.cuda.default_stream(device_index) + user_stream = torch.jit.cuda.Stream(device_index, 0) + + # Check if the current and default streams are the same on the device + is_current_and_default_stream_same = current_stream.id() == default_stream.id() + # Check if user stream and default stream are not the same on the device + is_default_and_user_stream_not_same = default_stream.id() != user_stream.id() + + with torch.jit.cuda.stream(user_stream): + is_stream_set = torch.cuda.current_stream(device_index).id() == user_stream.id() + + # Check if the stream was reset to current_stream + is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id() + + tensor1 = torch.rand(10000, 10000, device="cuda") + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + default_stream.synchronize() + default_stream_query = default_stream.query() + + # Capture all the results in the class Result + res = Result( + tensor1, tensor2, is_current_and_default_stream_same, + is_default_and_user_stream_not_same, is_stream_set, + is_stream_reset, default_stream_query, default_stream.id(), user_stream.id()) + return res + + result = test_get_stream() + + self.assertEqual(torch.matmul(result.t1, result.t1), result.t2) + self.assertTrue(result.is_current_and_default_stream_same) + self.assertTrue(result.is_default_and_user_stream_not_same) + self.assertTrue(result.is_stream_set) + self.assertTrue(result.is_stream_reset) + self.assertTrue(result.default_stream_query) + self.assertEqual(result.default_stream_id, 0) # Check if the default stream ID is always 0 + self.assertNotEqual(result.user_stream_id, 0) # Check if the user stream is always non zero + + # Test the stream context manager. This test checks if the stream is switched + # to the user stream on using the stream context manager. + @torch.jit.script + def test_stream_context(): + device_index = torch.cuda._current_device() + current_stream = torch.cuda.current_stream(device_index) + user_stream = torch.jit.cuda.Stream(device_index, 0) + A = torch.rand(1000, 1000, device="cuda") + + with torch.jit.cuda.stream(user_stream): + check = torch.cuda.current_stream(device_index).id() == user_stream.id() + B = torch.mm(A, A).to("cuda") + # Wait for B to be computed + user_stream.synchronize() + # Check if the stream has been reset on the current device + is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id() + + return A, B, check, is_stream_reset + + A, B, is_stream_set, is_stream_reset = test_stream_context() + self.assertEqual(torch.matmul(A, A), B) + self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!") + self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!") + + # Test multiple nested streams. 
Check if the operations are computed as expected on the streams + # This test has been adapted from the eager mode tests available at test/test_cuda.py + @torch.jit.script + def test_multiple_stream(): + prev_device_index = torch.cuda._current_device() + prev_current_stream = torch.cuda.current_stream(prev_device_index) + s1 = torch.jit.cuda.Stream(0, 0) + s2 = torch.jit.cuda.Stream(1, 0) + + A = torch.rand(1000, 1000, device="cuda") + B = torch.rand(1000, 1000, device="cuda") + with torch.jit.cuda.stream(s1): + C = torch.mm(A, A).to("cuda") + # Check if the stream and device have been set to s1 + is_stream_s1 = torch.cuda.current_stream(s1.device_index()).id() == s1.id() + is_device_s1 = torch.cuda._current_device() == s1.device_index() + with torch.jit.cuda.stream(s2): + # Check if the stream and device have been set to s2 + is_stream_s2 = torch.cuda.current_stream(s2.device_index()).id() == s2.id() + is_device_s2 = torch.cuda._current_device() == s2.device_index() + D = torch.mm(B, B).to("cuda") + # Check if the stream and device have been set to s1 + is_stream_s1_after = torch.cuda.current_stream(s1.device_index()).id() == s1.id() + is_device_s1_after = torch.cuda._current_device() == s1.device_index() + # Wait for D to be computed + s2.synchronize() + # Wait for C to be computed on S1 + s1.synchronize() + + # Check if the stream and device has been restored to previous stream and device + is_device_current = torch.cuda._current_device() == prev_device_index + is_stream_current = torch.cuda.current_stream(prev_device_index).id() == prev_current_stream.id() + + check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current + check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current + return A, B, C, D, check_stream, check_device + A, B, C, D, check_stream, check_device = test_multiple_stream() + + self.assertEqual(torch.matmul(A, A), C) + self.assertEqual(torch.matmul(B, B), D) + self.assertTrue(check_stream) + self.assertTrue(check_device) + + # Test multiple streams waiting on each other for the operations to be completed. + @torch.jit.script + def test_data_dependency_between_streams(): + device_index = torch.cuda._current_device() + prev_current_stream = torch.cuda.current_stream(device_index) + s1 = torch.jit.cuda.Stream(0, 0) + s2 = torch.jit.cuda.Stream(0, 0) + event = torch.jit.cuda.Event(False, False, False) + + A = torch.rand(1000, 1000, device="cuda") + with torch.jit.cuda.stream(s1): + is_stream_s1 = torch.cuda.current_stream(device_index).id() == s1.id() + B = torch.mm(A, A).to("cuda") + s1.record_event(event) + # Check if the current_stream is reset + is_current_stream_1 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id() + # Wait for ops on s1 to be computed + s2.wait_event(event) + with torch.jit.cuda.stream(s2): + is_stream_s2 = torch.cuda.current_stream(device_index).id() == s2.id() + C = torch.mm(B, B).to("cuda") + # Wait for C to be computed + s2.synchronize() + # Check if the current_stream is reset + is_current_stream_2 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id() + + check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2 + return A, B, C, check_stream + + A, B, C, check_stream = test_data_dependency_between_streams() + self.assertEqual(torch.matmul(A, A), B) + self.assertEqual(torch.matmul(B, B), C) + self.assertTrue(check_stream) + + # Test a simple CUDA event. 
Test if the CUDA event was created successfully + @torch.jit.script + def test_simple_event(): + e = torch.jit.cuda.Event(True, False, False) + return e is not None + self.assertTrue(test_simple_event(), "Could not create CUDA Event!") + + # Record the CUDA event for operation torch.mm on the current stream + # and then test if the elapsed time is greater than 0. This test is also + # an adaption from eager mdoe CUDA tests available at test/test_cuda.py + @torch.jit.script + def test_event(): + device_index = torch.cuda._current_device() + stream = torch.cuda.current_stream(device_index) + event = torch.jit.cuda.Event(True, False, False) + is_true_event_query = event.query() + start_event = torch.jit.cuda.Event(True, False, False) + stream.record_event(start_event) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + stream.record_event(event) + event.synchronize() + is_again_true_event_query = event.query() + + if not (is_true_event_query and is_again_true_event_query): + return -1.0 + return start_event.elapsed_time(event) + + self.assertGreater(test_event(), 0) + + # Check for stream synchronization , when a large tensor multiplication is + # computed on the stream. The stream.query should be true once the synchroniztion is done + @torch.jit.script + def test_stream_synchronize() -> float: + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + e_tik = torch.jit.cuda.Event(True, False, False) + e_tok = torch.jit.cuda.Event(True, False, False) + + e_tik.record(s) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + with torch.jit.cuda.stream(s): + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + s.synchronize() + e_tok.record(s) + e_tok.synchronize() + + if not s.query(): + return -1.0 + + # not necessary to check e_tik and e_tok, as elapsed_time would throw + # exception if otherwise. + return e_tik.elapsed_time(e_tok) + self.assertGreater(test_stream_synchronize(), 0) + + # Test event synchronization for the event that records a stream doing + # a large tensor multiplication. Check if the elapsed time is greater than 0 + # and the stream.query evaluates to true. + @torch.jit.script + def test_event_synchronize() -> float: + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + e_tik = torch.jit.cuda.Event(True, False, False) + e_tok = torch.jit.cuda.Event(True, False, False) + + e_tik.record(s) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + with torch.jit.cuda.stream(s): + tensor = torch.mm(tensor1, tensor1).to("cuda") + s.record_event(e_tok) + e_tok.synchronize() + s.synchronize() + + if not s.query(): + return -1.0 + + # not necessary to check e_tik and e_tok, as elapsed_time would throw + # exception if otherwise. + return e_tik.elapsed_time(e_tok) + + self.assertGreater(test_event_synchronize(), 0) + + # Test for event wait. Check if event waits for the all the operations on + # the stream to be done. Check for synchronizations and query on the streams + # and events. This test is adapted from eager mode tests for CUDA. 
Please refer + # test/test_cuda.py + @torch.jit.script + def test_event_wait() -> float: + device_index = torch.cuda._current_device() + s0 = torch.cuda.current_stream(device_index) + s1 = torch.jit.cuda.Stream(device_index, 0) + e_tik = torch.jit.cuda.Event(True, True, False) + e_tok = torch.jit.cuda.Event(True, True, False) + + e_tik.record(s0) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + with torch.jit.cuda.stream(s0): + tensor2 = torch.mm(tensor1, tensor1).cuda() + e_sync = torch.jit.cuda.Event(True, False, False) + e_sync.record(torch.cuda.current_stream(device_index)) + e_sync.wait(s1) + with torch.jit.cuda.stream(s1): + tensor3 = torch.rand(1000000000, 1000000000, device="cuda") + tensor4 = torch.mm(tensor3, tensor3).cuda() + s1.synchronize() + e_tok.record(torch.cuda.current_stream(device_index)) + e_tok.synchronize() + s0.synchronize() + + if not s0.query() or not s1.query() or not e_sync.query(): + return -1.0 + + # not necessary to check e_tik and e_tok, as elapsed_time would throw + # exception if otherwise. + return e_tik.elapsed_time(e_tok) + self.assertGreater(test_event_wait(), 0) + + # Test for stream wait_event. Checks if the stream waits on the event + @torch.jit.script + def test_wait_event(): + d1 = torch.device('cuda:1') + + with torch.jit.cuda.device(d1): + s0 = torch.cuda.current_stream(1) + tensor1 = torch.rand(1000000000, 1000000000, device="cuda") + tensor2 = torch.mm(tensor1, tensor1).to("cuda") + e0 = torch.jit.cuda.Event(False, False, False) + s0.record_event(e0) + + s1 = torch.cuda.current_stream(0) + s1.wait_event(e0) + s1.synchronize() + + return e0.query() and s0.query() and s1.query() + self.assertTrue(test_wait_event()) + + # Test if a scripted module with cuda streams can be saved, loaded and executed + def test_save_load(self): + class Model(torch.nn.Module): + def forward(self): + device_index = torch.cuda._current_device() + s = torch.jit.cuda.Stream(device_index, 0) + a = torch.rand(3, 4, device="cuda") + b = torch.rand(3, 4, device="cuda") + + with torch.jit.cuda.stream(s): + is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id() + c = torch.cat((a, b), 0).cuda() + s.synchronize() + return is_stream_s, a, b, c + + model = Model() + + # Script the model and save + script_model = torch.jit.script(model) + is_stream_s, a, b, c = script_model() + # Verify if the output is correct + self.assertTrue(is_stream_s) + self.assertEqual(torch.cat((a, b), 0), c) + + # Save and load scripted model + load_model = self.getExportImportCopy(script_model) + is_stream_s, a_load, b_load, c_load = load_model() + self.assertTrue(is_stream_s) + self.assertEqual(torch.cat((a_load, b_load), 0), c_load) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 059f59ff8702..366ca1af69e6 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -15,7 +15,7 @@ sys.path.append(pytorch_test_dir) from torch.testing._internal.common_utils import suppress_warnings, \ skipIfCompiledWithoutNumpy, enable_profiling_mode_for_profiling_tests, \ - IS_SANDCASTLE, IS_WINDOWS + IS_SANDCASTLE, TemporaryFileName from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \ _tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, RUN_CUDA_MULTI_GPU from torch.testing._internal.common_cuda import with_tf32_off @@ -25,7 +25,6 @@ # Standard library from collections import namedtuple from itertools import chain -import tempfile from typing import Dict import warnings @@ -1215,15 +1214,14 @@ def foo(x): 
self.run_pass('inline', traced_tensor_size.graph) FileCheck().check("prim::device").run(traced_tensor_size.graph) - @unittest.skipIf(IS_WINDOWS, "temp file name on windows") def test_trace_save(self): def fn(x): return x + 2 def check(func): - with tempfile.NamedTemporaryFile() as f: - func.save(f.name) - loaded = torch.jit.load(f.name) + with TemporaryFileName() as fname: + func.save(fname) + loaded = torch.jit.load(fname) input = torch.randn(2, 2) self.assertEqual(func(input), loaded(input)) diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index c676ccc0f793..a192eddca234 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -23,7 +23,7 @@ from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ - override_quantized_engine, supported_qengines, override_qengines + override_quantized_engine, supported_qengines, override_qengines, _snr from torch.testing._internal.common_quantized import qengine_is_qnnpack from torch.quantization import PerChannelMinMaxObserver @@ -2314,6 +2314,87 @@ def test_advanced_indexing(self): torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype) self.assertEqual(x_q_s4, x_fp32_s4_ref) + @override_qengines + def test_custom_module_lstm(self): + qengine = torch.backends.quantized.engine + + batch_size = 4 + seq_len = 8 + input_size = 12 + + hidden_size = 8 + num_layers = 2 + + dropout = 0 # This is not supported + + Bias = [False, True] + Batch_first = [False, True] + Bidirectional = [False, True] + + dtype = np.uint8 + qtype = torch.quint8 + + custom_module_config = { + 'float_to_observed_custom_module_class': { + torch.nn.LSTM: torch.nn.quantizable.LSTM + } + } + + x = np.random.randn(seq_len, batch_size, input_size) + scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype) + x = torch.from_numpy(x).to(torch.float) + qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, + dtype=qtype) + x = qx.dequantize() + + with torch.no_grad(): + for bias, batch_first, bidirectional in itertools.product( + Bias, Batch_first, Bidirectional): + # Assume 12dB is sufficient for functional equivalence + # Without the bias, linear performs poorly + min_power = 10 if bias else 5 + max_mse = 5e-6 if bias else 5e-1 + + if batch_first: + x = x.reshape(batch_size, seq_len, input_size) + qx = qx.reshape(batch_size, seq_len, input_size) + else: + x = x.reshape(seq_len, batch_size, input_size) + qx = qx.reshape(seq_len, batch_size, input_size) + + lstm = torch.nn.Sequential( + torch.nn.LSTM(input_size, hidden_size, + num_layers=num_layers, + bias=bias, batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional)) + lstm.eval() + y_ref = lstm(x) + + # Prepare + lstm.qconfig = torch.quantization.get_default_qconfig(qengine) + lstm_prepared = torch.quantization.prepare( + lstm, prepare_custom_config_dict=custom_module_config) + self.assertTrue(hasattr(lstm_prepared[0], 'layers')) + self.assertEqual(num_layers, len(lstm_prepared[0].layers)) + + # Calibrate + y = lstm_prepared(x) + self.assertEqual(y_ref, y) + + # Quantize + lstm_quantized = torch.quantization.convert(lstm_prepared) + qy = lstm_quantized(qx) + + snr = _snr(y, qy) + snr = [snr[0]] + snr[1] + + for signal, mse, power in snr: + self.assertTrue( + power > min_power or mse < max_mse, + msg=(f"Error is too 
high: SNR(dB): {power}, " + f"Signal: {signal}, MSE: {mse}")) + class TestDynamicQuantizedLinear(TestCase): """Tests the correctness of the dynamic quantized linear and linear_relu op.""" @@ -3346,7 +3427,7 @@ def _make_qconv_tensors( self, batch_size, input_channels_per_group, input_feature_map_shape, output_channels_per_group, groups, kernels, strides, pads, dilations, X_scale, X_zero_point, W_scale, W_zero_point, - use_bias, use_channelwise, use_transpose, memory_format=torch.contiguous_format + use_bias, use_channelwise, use_transpose ): assert not (use_channelwise and use_transpose), \ "Cannot generate channelwise qconv_transpose_tensors " @@ -3394,7 +3475,6 @@ def _make_qconv_tensors( (batch_size, input_channels,) + input_feature_map_shape, ) X = X_scale * (X_init - X_zero_point).float() - X = X.to(memory_format=memory_format) if use_channelwise: W_shape = (-1, 1) + (1,) * len(kernels) @@ -3427,15 +3507,13 @@ def _test_qconv_impl( input_channels_per_group, input_feature_map_shape, output_channels_per_group, groups, kernels, strides, pads, o_pads, dilations, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, - Y_zero_point, use_bias, use_relu, use_channelwise, use_transpose, - memory_format=torch.contiguous_format + Y_zero_point, use_bias, use_relu, use_channelwise, use_transpose ): (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors( batch_size, input_channels_per_group, input_feature_map_shape, output_channels_per_group, groups, kernels, strides, pads, dilations, X_scale, X_zero_point, W_scale, - W_zero_point, use_bias, use_channelwise, use_transpose, - memory_format) + W_zero_point, use_bias, use_channelwise, use_transpose) # Assign weights W = W_q.dequantize() X = X_q.dequantize() @@ -3483,14 +3561,6 @@ def _test_qconv_impl( pads: {pads}, o_pads: {o_pads}, dilations: {dilations}, groups: {groups}, y_s: {Y_scale}, y_zp: {Y_zero_point}''') - # fbgemm for now forces output to be NHWC (channels last) to opportunistically - # improve performance - if torch.backends.quantized.engine == 'qnnpack': - # Make sure memory format is preserved - self.assertEqual( - X_q.is_contiguous(memory_format=memory_format), - Y_q.is_contiguous(memory_format=memory_format)) - # Return the quantized data for later reuse return X_q, W_q, bias_float @@ -3563,14 +3633,12 @@ def test_qconv2d( dilations, groups, ) - for memory_format in (torch.contiguous_format, torch.channels_last): - self._test_qconv_impl( - qconv, qconv_prepack, conv_op, batch_size, - input_channels_per_group, (height, width), - output_channels_per_group, groups, kernels, strides, pads, None, - dilations, X_scale, X_zero_point, W_scale, W_zero_point, - Y_scale, Y_zero_point, use_bias, use_relu, use_channelwise, False, - memory_format) + self._test_qconv_impl( + qconv, qconv_prepack, conv_op, batch_size, + input_channels_per_group, (height, width), + output_channels_per_group, groups, kernels, strides, pads, None, + dilations, X_scale, X_zero_point, W_scale, W_zero_point, + Y_scale, Y_zero_point, use_bias, use_relu, use_channelwise, False) """Tests the correctness of quantized convolution op.""" @given(batch_size=st.integers(1, 3), @@ -4163,7 +4231,6 @@ def test_qconv3d_unpack( (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad), channelwise) - class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), channels=st.integers(1, 64), diff --git a/test/test_autograd.py b/test/test_autograd.py index 7ce73683835f..34c38eefa342 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -32,7 +32,8 @@ 
from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack, suppress_warnings, slowTest, load_tests, random_symmetric_matrix, - IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck) + IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck, + TemporaryFileName) from torch.autograd import Variable, Function, detect_anomaly, kineto_available from torch.autograd.function import InplaceFunction import torch.autograd.forward_ad as fwAD @@ -2970,6 +2971,20 @@ def run_test(input_size, norm_deg): run_test((10,), 3) run_test((10,), 1) run_test((10,), 1.5) + run_test((10,), inf) + + def test_norm_inf_subgradient(self): + def run_test(input, expected, dim=None): + x = torch.tensor(input, requires_grad=True) + out = x.norm(inf, dim=dim, keepdim=True) + out.backward(torch.ones(out.size())) + self.assertEqual(x.grad, expected) + + run_test([0., 0., 0.], [0., 0., 0.]) + run_test([1., 0., 1.], [0.5, 0., 0.5]) + run_test([[1., 0., 1.], [0., 1., 1.]], [[0.25, 0., 0.25], [0., 0.25, 0.25]]) + run_test([[1., 0., 1.], [0., 1., 0.]], [[0.5, 0., 0.5], [0., 1., 0.]], (1,)) + run_test(torch.ones((2, 2, 2)), torch.full((2, 2, 2), 0.25), (0, 2)) def test_pow_zero_tensor_gradient(self): def run_test(input_size, exponent): @@ -3015,18 +3030,17 @@ def gen_matrices(p): gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6])) gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10])) - @unittest.skipIf(IS_WINDOWS, """File open permission error on Windows, - https://github.com/pytorch/pytorch/issues/34086""") def test_profiler_tracing(self): t1, t2 = torch.ones(1), torch.ones(1) with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: torch.add(t1, t2) - with tempfile.NamedTemporaryFile(mode="w+") as f: - prof.export_chrome_trace(f.name) + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) # read the trace and expect valid json # if the JSON generated by export_chrome_trace is not valid, this will throw and fail the test. - json.load(f) + with io.open(fname, 'r') as f: + json.load(f) # Same test but for cuda. 
if not torch.cuda.is_available(): @@ -3037,10 +3051,11 @@ def test_profiler_tracing(self): with torch.autograd.profiler.profile(use_cuda=True, use_kineto=kineto_available()) as prof: torch.add(t1, t2) - with tempfile.NamedTemporaryFile(mode="w+") as f: - prof.export_chrome_trace(f.name) + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) # Now validate the json - json.load(f) + with io.open(fname, 'r') as f: + json.load(f) def test_profiler(self): x = torch.randn(10, 10) diff --git a/test/test_jit.py b/test/test_jit.py index 1acdbb42f54e..a683a8eb0b8c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -35,6 +35,7 @@ from jit.test_slice import TestSlice # noqa: F401 from jit.test_warn import TestWarn # noqa: F401 from jit.test_isinstance import TestIsinstance # noqa: F401 +from jit.test_cuda import TestCUDA # noqa: F401 from jit.test_hash import TestHash # noqa: F401 # Torch @@ -2393,8 +2394,7 @@ def fn(x): warns = [str(w.message) for w in warns] self.assertEqual(len(warns), 0) - @unittest.skipIf(IS_WINDOWS or True, "TODO: need to fix this test case for " - "Windows, re-enable with https://github.com/pytorch/pytorch/pull/29339") + @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339") def test_torch_load_error(self): class J(torch.jit.ScriptModule): def __init__(self): @@ -2405,20 +2405,20 @@ def forward(self, input): return input + 100 j = J() - with tempfile.NamedTemporaryFile() as f: - j.save(f.name) + with TemporaryFileName() as fname: + j.save(fname) with self.assertRaisesRegex(RuntimeError, "is a zip"): - torch.load(f.name) + torch.load(fname) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") def test_torch_load_zipfile_check(self): @torch.jit.script def fn(x): return x + 10 - with tempfile.NamedTemporaryFile() as f: - fn.save(f.name) - self.assertTrue(torch.serialization._is_zipfile(f)) + with TemporaryFileName() as fname: + fn.save(fname) + with io.open(fname, 'rb') as f: + self.assertTrue(torch.serialization._is_zipfile(f)) def test_python_bindings(self): lstm_cell = torch.jit.script(LSTMCellS) diff --git a/test/test_nn.py b/test/test_nn.py index 1d63be6e3075..386ba369dca6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9283,18 +9283,19 @@ def test_flatten(self): def test_unflatten(self): tensor_input = torch.randn(2, 50) - # Unflatten Tensor + # Unflatten Tensor (unflattened_size as a tuple of ints and list of ints) - unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5)) - tensor_output = unflatten(tensor_input) - self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) + for us in ((2, 5, 5), [2, 5, 5]): + unflatten = nn.Unflatten(dim=1, unflattened_size=us) + tensor_output = unflatten(tensor_input) + self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) # Unflatten NamedTensor unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5))) named_tensor_input = tensor_input.refine_names('N', 'features') named_tensor_output = unflatten(named_tensor_input) - self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) + self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5])) def test_unflatten_invalid_arg(self): # Wrong type for unflattened_size (tuple of floats) @@ -9304,6 +9305,13 @@ def test_unflatten_invalid_arg(self): r"unflattened_size must be tuple of ints, but found element of type float at pos 2"): nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0)) + # Wrong type for unflattened_size (list of lists 
and list of tuples) + for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]): + with self.assertRaisesRegex( + TypeError, + r"unflattened_size must be a tuple of tuples, but found type list"): + nn.Unflatten(dim='features', unflattened_size=us) + # Wrong type for unflattened_size (tuple of lists) with self.assertRaisesRegex( @@ -9311,19 +9319,12 @@ def test_unflatten_invalid_arg(self): r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"): nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5])) - # Wrong type for unflattened_size (list of ints) - - with self.assertRaisesRegex( - TypeError, - r"unflattened_size must be a tuple of ints, but found type list"): - nn.Unflatten(dim=1, unflattened_size=[2, 5, 5]) - - # Wrong type for unflattened_size (list of lists) + # Wrong type for unflattened_size (tuple of dicts) with self.assertRaisesRegex( TypeError, - r"unflattened_size must be a tuple of tuples, but found type list"): - nn.Unflatten(dim='features', unflattened_size=[['C', 2], ['W', 5], ['H', 5]]) + r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"): + nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5})) def test_layer_norm_grads_with_create_graph_flag(self): atol = 1e-5 diff --git a/test/test_overrides.py b/test/test_overrides.py index 95f94504d84e..f32b04cb2e53 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -563,6 +563,8 @@ def instance_gen(): func_args.append(instance_gen()) elif t == 'TensorList': func_args.append([instance_gen(), instance_gen()]) + elif t == 'c10::List>': + func_args.append([instance_gen(), instance_gen()]) elif t == 'IntArrayRef': size = arg.get('size', 2) if size == 1: diff --git a/test/test_profiler.py b/test/test_profiler.py index d24fabe76998..826a9f5d0b57 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -1,14 +1,14 @@ import collections import gc +import io import unittest -import tempfile import torch import torch.nn as nn import torch.optim import torch.utils.data from torch.testing._internal.common_utils import ( - TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS) + TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName) import torch.autograd.profiler as profiler from torch.autograd.profiler import profile from torch.autograd import kineto_available @@ -282,7 +282,6 @@ def trace_handler(p): print(p.key_averages().table( sort_by="self_cuda_time_total", row_limit=-1)) - @unittest.skipIf(IS_WINDOWS, "Disabled on windows (permissions)") def test_export_stacks(self): with profile(with_stack=True, use_kineto=kineto_available()) as p: x = torch.randn(10, 10) @@ -290,9 +289,10 @@ def test_export_stacks(self): z = torch.mm(x, y) z = z + y - with tempfile.NamedTemporaryFile(mode="w+") as f: - p.export_stacks(f.name) - lines = f.readlines() + with TemporaryFileName(mode="w+") as fname: + p.export_stacks(fname) + with io.open(fname, 'r') as f: + lines = f.readlines() assert len(lines) > 0, "Empty stacks file" for line in lines: is_int = False diff --git a/test/test_quantization.py b/test/test_quantization.py index f68bfcd058b6..1c370913c6d0 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -15,6 +15,7 @@ from quantization.test_quantized_op import TestPadding # noqa: F401 from quantization.test_quantized_op import TestQuantizedEmbeddingOps # noqa: F401 from quantization.test_quantized_op import TestDynamicQuantizedRNNOp # noqa: F401 + # Quantized Functional from 
quantization.test_quantized_functional import TestQuantizedFunctional # noqa: F401 diff --git a/test/test_serialization.py b/test/test_serialization.py index 8fd5926caa82..916f133c3fe1 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -17,7 +17,7 @@ from torch.serialization import check_module_version_greater_or_equal from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \ - TEST_DILL, run_tests, download_file, BytesIOContext + TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName from torch.testing._internal.common_device_type import instantiate_device_type_tests # These tests were all copied from `test/test_torch.py` at some point, so see @@ -137,25 +137,22 @@ def test(name_or_buffer): with tempfile.NamedTemporaryFile() as f: test(f) - if sys.platform != "win32": - with tempfile.NamedTemporaryFile() as f: - test(f.name) + with TemporaryFileName() as fname: + test(fname) test(io.BytesIO()) def test_serialization(self): # Test serialization with a real file b = self._test_serialization_data() - for use_name in (False, True): - # Passing filename to torch.save(...) will cause the file to be opened twice, - # which is not supported on Windows - if sys.platform == "win32" and use_name: - continue - with tempfile.NamedTemporaryFile() as f: - handle = f if not use_name else f.name - torch.save(b, handle) - f.seek(0) - c = torch.load(handle) + with tempfile.NamedTemporaryFile() as f: + torch.save(b, f) + f.seek(0) + c = torch.load(f) + self._test_serialization_assert(b, c) + with TemporaryFileName() as fname: + torch.save(b, fname) + c = torch.load(fname) self._test_serialization_assert(b, c) # test non-ascii encoding of bytes arrays/strings # The following bytes are produced by serializing @@ -716,9 +713,8 @@ def test(name_or_buffer): with tempfile.NamedTemporaryFile() as f: test(f) - if sys.platform != "win32": - with tempfile.NamedTemporaryFile() as f: - test(f.name) + with TemporaryFileName() as fname: + test(fname) test(io.BytesIO()) @@ -737,12 +733,11 @@ def test_serialization_2gb_file(self): f.seek(0) state = torch.load(f) - @unittest.skipIf(IS_WINDOWS, "torch.save with filename will open file twice, not supported in Windows.") def test_pathlike_serialization(self): model = torch.nn.Conv2d(20, 3200, kernel_size=3) - with tempfile.NamedTemporaryFile() as f: - path = pathlib.Path(f.name) + with TemporaryFileName() as fname: + path = pathlib.Path(fname) torch.save(model, path) torch.load(path) diff --git a/test/test_throughput_benchmark.py b/test/test_throughput_benchmark.py index d2f993ddaa3a..9d60344b5912 100644 --- a/test/test_throughput_benchmark.py +++ b/test/test_throughput_benchmark.py @@ -1,10 +1,9 @@ import torch -import tempfile from torch.utils import ThroughputBenchmark from torch.testing import assert_allclose -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName class TwoLayerNet(torch.jit.ScriptModule): def __init__(self, D_in, H, D_out): @@ -76,8 +75,8 @@ def test_module(self): self.linear_test(TwoLayerNetModule) def test_profiling(self): - with tempfile.NamedTemporaryFile(delete=False) as f: - self.linear_test(TwoLayerNetModule, profiler_output_path=f.name) + with TemporaryFileName() as fname: + self.linear_test(TwoLayerNetModule, profiler_output_path=fname) if __name__ == '__main__': diff --git a/test/test_torch.py b/test/test_torch.py index 6532c2e5e17d..8872516ddd28 100644 --- a/test/test_torch.py +++ 
b/test/test_torch.py @@ -6870,7 +6870,6 @@ def inner(self, device, dtype): ('rot90', 'k1_d12', _small_3d, lambda t, d: [1, [1, 2]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False), ('rot90', 'k1_neg_d', _small_3d, lambda t, d: [1, [1, -1]], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False), ('rot90', 'default', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types + _complex_types, _cpu_types, False), - ('rsqrt', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-2, 1e-5, 1e-4, _float_types_no_half), ('sinh', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types), ('tan', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types), ('tan', 'complex', lambda t, d: _small_3d(t, d), lambda t, d: [], 1e-3, 1e-5, 1e-5, _complex_types), diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 776482306f4d..1daecc24f79f 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1715,7 +1715,6 @@ def _medium_2d(dtype, device): _TorchMathTestMeta('ceil'), _TorchMathTestMeta('rad2deg'), _TorchMathTestMeta('deg2rad'), - _TorchMathTestMeta('rsqrt', reffn=lambda x: np.reciprocal(np.sqrt(x))), _TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)), _TorchMathTestMeta('trunc'), _TorchMathTestMeta('round'), diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index a22154b5c01d..4724b99a8742 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -141,7 +141,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str compute_index_ranges: List[str] = [] for arg in info.args_with_derivatives: - if arg.type == 'TensorList': + if arg.type == 'TensorList' or arg.type == 'const c10::List> &': size = f'{arg.name}_size_' saved_list_sizes.append(f'size_t {arg.name}_size_;') else: @@ -166,6 +166,15 @@ def save_var(var: SavedAttribute, is_output: bool) -> None: release_variables.append(f'{name}_released_ = true;') unpack.append(f'auto {name} = unpack_list({name}_);') asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);') + elif var.type == 'c10::List>': + saved_variables.append(f'std::vector {name}_;') + saved_variables.append(f'bool {name}_released_ = false;') + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. 
+ release_variables.append(f'{name}_.clear();') + release_variables.append(f'{name}_released_ = true;') + unpack.append(f'auto {name} = unpack_opt_list({name}_);') + asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);') elif var.type == 'IntArrayRef': saved_variables.append(f'std::vector {name};') elif var.type == 'c10::optional': diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index b2dfe2667128..78c460843d94 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -112,9 +112,8 @@ def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequen ] else: name = arg.name - # XXX: For arg that have type of Tensor?[], tracer will pass allow_undefined to addInputs if str(arg.type) == 'Tensor?[]': - return [f'jit::tracer::addInputs(node, "{name}", {name}, true);'] + return [f'jit::tracer::addInputs(node, "{name}", {name});'] else: return [ADD_TRACE_INPUT.substitute(name=name, input=name)] diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 1d75ae46e9c9..f97fb55ab012 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -118,6 +118,21 @@ } """) +SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\ +std::vector> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const c10::optional& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_value() && tensor->has_storage() ? c10::optional(tensor->storage()) : c10::nullopt); +""") + +ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\ +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value()) + AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of( + static_cast>(${tensorlist_name}[i])->storage())); +} +""") + SAVE_TENSOR_IMPL = CodeTemplate("""\ c10::intrusive_ptr ${tensor_name}_impl_saved; if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr(); @@ -140,6 +155,21 @@ } """) +SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + c10::optional t = ${tensorlist_name}[i]; + if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr(); +} +""") + +ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\ +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + if (${tensorlist_name}_impl_saved[i]) + AT_ASSERT(${tensorlist_name}_impl_saved[i] == static_cast>(${tensorlist_name}[i])->getIntrusivePtr()); +} +""") + # The following list contains functions that we don't enforce the invariant on. 
DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = { # These functions are expected to change impl or storage of input tensors @@ -466,7 +496,8 @@ def emit_save_inputs(): if func is None: return setup - has_tensorlist_arg = any(arg.type == 'TensorList' for arg in func.args_with_derivatives) + has_tensorlist_arg = \ + any(arg.type in ['TensorList', 'const c10::List> &'] for arg in func.args_with_derivatives) # We don't want to save tensors if we know that they will never be used # when computing the derivative, so we add guards to those statements @@ -515,7 +546,7 @@ def guard_for(arg: SavedAttribute) -> Optional[str]: setup.extend(save_variables(func.all_saved_inputs, False, guard_for)) for arg in func.args_with_derivatives: - if arg.type == 'TensorList': + if arg.type in ['TensorList', 'const c10::List> &']: setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();') return setup @@ -554,7 +585,7 @@ def emit_check_if_in_complex_autograd_allowlist(): return body for arg in differentiable_outputs: name = arg['name'] - if arg['type'] == 'Tensor' or arg['type'] == 'TensorList': + if arg['type'] in ['Tensor', 'TensorList', 'const c10::List> &']: body.append('throw_error_for_complex_autograd({}, "{}");'.format(name, base_name)) return body @@ -599,7 +630,7 @@ def save_variables( expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})' else: expr = f'SavedVariable({var}, {str(is_output).lower()})' - elif arg.type == 'TensorList': + elif arg.type in ['TensorList', 'c10::List>']: name += '_' expr = f'make_saved_variable_list({arg.name})' elif arg.type == 'IntArrayRef': @@ -699,7 +730,7 @@ def wrap_output(return_values, var): # Only allow rebasing of the history if we return a single Tensor # If we are in a no grad block, raise a warning # See NOTE [ View + Inplace detection ] for more details about this logic - if return_info['dynamic_type'] == 'TensorList': + if return_info['dynamic_type'] in ['TensorList', 'const c10::List> &']: if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS: creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE" else: @@ -736,6 +767,11 @@ def enforce_same_tensorimpl_and_storage(env, call): SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)] enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)] + elif simple_type == 'c10::List>': + save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)] + enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)] elif simple_type == 'Tensor': save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg), SAVE_TENSOR_IMPL.substitute(tensor_name=arg)] @@ -836,7 +872,7 @@ def emit_increment_version(): def unpack_args(env, declaration): def requires_unpack(arg): - return 'Tensor' in arg['dynamic_type'] + return 'Tensor' in arg['dynamic_type'] and 'c10::optional' not in arg['type'] body = [] unpacked_args = [] @@ -855,9 +891,8 @@ def requires_unpack(arg): dynamic_type = arg['dynamic_type'] if 'TensorOptions' not in dynamic_type: is_nullable = arg.get('is_nullable', False) - ref = (not is_nullable) and dynamic_type not in ['TensorList'] + ref = (not is_nullable) and dynamic_type != 'TensorList' suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else '' - body.append(UNPACK_TENSOR.substitute( arg_name=arg['name'], 
arg_pos=i, diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h index 03240e2a5a2b..0540bb65b33b 100644 --- a/tools/autograd/templates/Functions.h +++ b/tools/autograd/templates/Functions.h @@ -32,6 +32,15 @@ inline std::vector unpack_list(at::ArrayRef xs) { }); } +inline c10::List> unpack_opt_list(at::ArrayRef xs) { + torch::List> result; + result.reserve(xs.size()); + for (const SavedVariable& v : xs) { + result.push_back(v.unpack()); + } + return result; +} + struct TypeAndSize { TypeAndSize() : options(at::TensorOptions()) {} /* implicit */ diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h index 9062a4d08e34..fc8ffa5799c1 100644 --- a/tools/autograd/templates/VariableType.h +++ b/tools/autograd/templates/VariableType.h @@ -49,7 +49,6 @@ namespace VariableType { at::Tensor & unpack(Tensor & t, const char * name, int pos); const at::Tensor & unpack(const Tensor & t, const char * name, int pos); at::Tensor unpack_opt(const Tensor & t, const char * name, int pos); - c10::optional unpack_opt(const c10::optional & t, const char * name, int pos); std::vector unpack(at::TensorList tl, const char *name, int pos); }; diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index a214684ab29c..8eeffe724c8e 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -408,6 +408,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp", "torch/csrc/jit/codegen/cuda/type.cpp", "torch/csrc/jit/tensorexpr/cuda_codegen.cpp", + "torch/csrc/jit/runtime/register_cuda_ops.cpp", ] libtorch_cuda_sources = libtorch_cuda_core_sources + [ @@ -503,7 +504,6 @@ libtorch_python_core_sources = [ "torch/csrc/MemoryFormat.cpp", "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", - "torch/csrc/PtrWrapper.cpp", "torch/csrc/python_dimname.cpp", "torch/csrc/Size.cpp", "torch/csrc/Storage.cpp", diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index ffd9626601a0..c27a6768300a 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -104,9 +104,11 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType: return BaseCType("TensorList", binds) elif str(t.elem) == 'Dimname': return BaseCType("DimnameList", binds) - # TODO: do something reasonable about lists of optional tensors - elif (not local.use_c10_dispatcher().dispatcher_uses_new_style()) and str(t.elem) == 'Tensor?': - return BaseCType("TensorList", binds) + elif str(t.elem) == 'Tensor?': + if local.use_c10_dispatcher().dispatcher_uses_new_style(): + return BaseCType("const c10::List> &", binds) + else: + return BaseCType("TensorList", binds) elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) # TODO: explicitly qualify namespace here return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds) diff --git a/tools/codegen/api/native.py b/tools/codegen/api/native.py index 3b793527edd9..9781c46884e7 100644 --- a/tools/codegen/api/native.py +++ b/tools/codegen/api/native.py @@ -34,7 +34,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType: else: return ConstRefCType(BaseCType('Tensor', binds)) elif str(t) == 'Tensor?[]': - return BaseCType('TensorList', binds) + return BaseCType('const c10::List> &', binds) return cpp.argumenttype_type(t, mutable=mutable, binds=binds) def returns_type(rs: Sequence[Return]) -> str: diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index 059032869675..bdb31d4d8616 100644 --- a/tools/codegen/api/python.py +++ 
b/tools/codegen/api/python.py @@ -228,7 +228,7 @@ class PythonArgument: # Compute argument formal for python argument parsing. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. def argument_str(self, *, method: bool = False) -> str: - type_str = argument_type_str(self.type) + type_str = argument_type_str(self.type).replace('const ', '').replace(' &', '') name = self.name # s/self/input/ outside method bindings @@ -624,10 +624,9 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: return f'ScalarList[{size}]' if size is not None else 'ScalarList' elif str(t.elem) == 'Tensor?': if simple_type: - return 'TensorList' + return 'c10::List>' else: - # TODO: clone the old codegen behavior but does it make sense? - return 'TensorList?' + return 'const c10::List> &' elif str(t.elem) == 'Dimname': return f'DimnameList[{size}]' if size is not None else 'DimnameList' elem = argument_type_str(t.elem, simple_type=simple_type) @@ -1051,12 +1050,14 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str: return 'toDimnameListOptional' elif isinstance(t, ListType): - if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?': + if str(t.elem) == 'Tensor': # accept and use definite size if t.size is not None: return f'tensorlist_n<{t.size}>' else: return 'tensorlist' + elif str(t.elem) == 'Tensor?': + return 'list_of_optional_tensors' elif str(t.elem) == 'Dimname': # accept definite size return 'dimnamelist' diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index cfcb66896ad7..15a286f2370c 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -25,7 +25,8 @@ class ProfilerConfig: state: ProfilerState, report_input_shapes: bool, profile_memory: bool, - with_stack: bool + with_stack: bool, + with_flops: bool ) -> None: ... ... 
diff --git a/torch/__init__.py b/torch/__init__.py index 04955623ab2a..9ae1010a3ba8 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -574,6 +574,7 @@ def _assert(condition, message): import torch.futures import torch.nn import torch.nn.intrinsic +import torch.nn.quantizable import torch.nn.quantized import torch.optim import torch.optim._multi_tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9c767822b11b..fe7237b5a370 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1026,7 +1026,6 @@ def merge_dicts(*dicts): tensor([ 0, 1, -4], dtype=torch.int8) """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.bmm, r""" bmm(input, mat2, *, deterministic=False, out=None) -> Tensor @@ -2934,7 +2933,6 @@ def merge_dicts(*dicts): tensor([ 0., 1.]) """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.eye, r""" eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -2944,6 +2942,8 @@ def merge_dicts(*dicts): Args: n (int): the number of rows m (int, optional): the number of columns with default being :attr:`n` + +Keyword arguments: {out} {dtype} {layout} @@ -4174,7 +4174,6 @@ def merge_dicts(*dicts): tensor([ 0.5724, 0.0000, -0.1208]) """.format(**common_args)) -# TODO: update kwargs formatting (see https://github.com/pytorch/pytorch/issues/43667) add_docstr(torch.linspace, r""" linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -4201,6 +4200,8 @@ def merge_dicts(*dicts): start (float): the starting value for the set of points end (float): the ending value for the set of points steps (int): size of the constructed tensor + +Keyword arguments: {out} {dtype} {layout} @@ -4537,7 +4538,6 @@ def merge_dicts(*dicts): tensor([ True, True, False, False]) """.format(**common_args)) -# TODO: update kwargs formatting (see https://github.com/pytorch/pytorch/issues/43667) add_docstr(torch.logspace, """ logspace(start, end, steps, base=10.0, *, \ out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -4568,7 +4568,9 @@ def merge_dicts(*dicts): start (float): the starting value for the set of points end (float): the ending value for the set of points steps (int): size of the constructed tensor - base (float): base of the logarithm function. Default: ``10.0``. + base (float, optional): base of the logarithm function. Default: ``10.0``. + +Keyword arguments: {out} {dtype} {layout} @@ -5469,36 +5471,15 @@ def merge_dicts(*dicts): add_docstr(torch.argmin, r""" -argmin(input) -> LongTensor +argmin(input, dim=None, keepdim=False) -> LongTensor -Returns the indices of the minimum value of all elements in the :attr:`input` tensor. +Returns the indices of the minimum value(s) of the flattened tensor or along a dimension This is the second value returned by :meth:`torch.min`. See its documentation for the exact semantics of this method. .. note:: If there are multiple minimal values then the indices of the first minimal value are returned. -Args: - {input} - -Example:: - - >>> a = torch.randn(4, 4) - >>> a - tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], - [ 1.0100, -1.1975, -0.0102, -0.4732], - [-0.9240, 0.1207, -0.7506, -1.0213], - [ 1.7809, -1.2960, 0.9384, 0.1438]]) - >>> torch.argmin(a) - tensor(13) - -.. 
function:: argmin(input, dim, keepdim=False) -> LongTensor - -Returns the indices of the minimum values of a tensor across a dimension. - -This is the second value returned by :meth:`torch.min`. See its -documentation for the exact semantics of this method. - Args: {input} {dim} If ``None``, the argmin of the flattened input is returned. @@ -5512,8 +5493,15 @@ def merge_dicts(*dicts): [ 1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) >>> torch.argmin(a, dim=1) tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) """.format(**single_dim_common)) add_docstr(torch.mm, @@ -6328,7 +6316,6 @@ def merge_dicts(*dicts): """.format(**common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.ones, r""" ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor @@ -6339,6 +6326,8 @@ def merge_dicts(*dicts): Args: size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. + +Keyword arguments: {out} {dtype} {layout} @@ -6356,7 +6345,6 @@ def merge_dicts(*dicts): """.format(**factory_common_args)) -# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.ones_like, r""" ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor @@ -6372,6 +6360,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword arguments: {dtype} {layout} {device} @@ -8260,7 +8250,7 @@ def merge_dicts(*dicts): Args: input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric matrices. - eigenvectors(boolean, optional): controls whether eigenvectors have to be computed + eigenvectors(bool, optional): controls whether eigenvectors have to be computed upper(boolean, optional): controls whether to consider upper-triangular or lower-triangular region Keyword args: @@ -9270,7 +9260,7 @@ def merge_dicts(*dicts): add_docstr(torch.full_like, """ -full_like(input, fill_value, \\*, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +full_like(input, fill_value, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ memory_format=torch.preserve_format) -> Tensor Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. @@ -9489,9 +9479,10 @@ def merge_dicts(*dicts): Batched version for complex inputs is only supported on the CPU. Arguments: - input (Tensor): The input tensor of size :math:`(*, m, n)` where :math:`*` is zero or more batch dimensions - rcond (float): A floating point value to determine the cutoff for small singular values. - Default: 1e-15 + input (Tensor): The input tensor of size :math:`(*, m, n)` where :math:`*` is + zero or more batch dimensions. + rcond (float, optional): A floating point value to determine the cutoff for + small singular values. Default: ``1e-15``. Returns: The pseudo-inverse of :attr:`input` of dimensions :math:`(*, n, m)` @@ -9887,6 +9878,8 @@ def merge_dicts(*dicts): Arguments: y (Tensor): The values of the function to integrate + +Keyword args: dx (float): The distance between points at which `y` is sampled. dim (int): The dimension along which to integrate. By default, use the last dimension. 
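Several docstring entries above gain an explicit "Keyword arguments:" section (torch.trapz's dx and dim among them), and the two argmin entries are merged into a single argmin(input, dim=None, keepdim=False) signature. A short, hand-checked sketch of the documented behavior, with values chosen for this illustration rather than taken from the patch:

    import torch

    # Trapezoidal rule with unit spacing:
    # 0.5*(1 + 2) + 0.5*(2 + 3) = 4.0
    y = torch.tensor([1., 2., 3.])
    print(torch.trapz(y, dx=1.0))                # tensor(4.)

    # dim/keepdim on argmin, as in the consolidated docstring entry.
    a = torch.tensor([[1., 0., 2.],
                      [3., 4., -1.]])
    print(torch.argmin(a, dim=1, keepdim=True))  # tensor([[1], [2]])
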
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index a5c078e84f4c..a3d0da1aef9d 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -468,7 +468,8 @@ def config(self): self.profiler_kind, self.record_shapes, self.profile_memory, - self.with_stack) + self.with_stack, + self.with_flops) def __enter__(self): if not self.enabled: @@ -746,6 +747,7 @@ def __enter__(self): torch.autograd.ProfilerState.NVTX, self.record_shapes, False, + False, False) ) return self diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index aeac5bafd56f..f70bd1a0ad95 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -722,7 +722,6 @@ PyObject* initModule() { methods.data() }; ASSERT_TRUE(module = PyModule_Create(&torchmodule)); - ASSERT_TRUE(THPWrapper_init(module)); ASSERT_TRUE(THPGenerator_init(module)); ASSERT_TRUE(THPException_init(module)); THPSize_init(module); diff --git a/torch/csrc/PtrWrapper.cpp b/torch/csrc/PtrWrapper.cpp deleted file mode 100644 index aa48c49949b9..000000000000 --- a/torch/csrc/PtrWrapper.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include -#include -#include - -static PyObject* THPWrapperClass = nullptr; - -struct THPWrapper { - PyObject_HEAD - void *data; - void (*destructor)(void*); -}; - -PyObject * THPWrapper_New(void *data, void (*destructor)(void*)) -{ - PyObject *args = PyTuple_New(0); - if (!args) { - return nullptr; - } - PyObject *result = PyObject_Call(THPWrapperClass, args, nullptr); - if (result) { - THPWrapper* wrapper = (THPWrapper*) result; - wrapper->data = data; - wrapper->destructor = destructor; - } - Py_DECREF(args); - return result; -} - -bool THPWrapper_check(PyObject * obj) -{ - return (PyObject*)Py_TYPE(obj) == THPWrapperClass; -} - -void * THPWrapper_get(PyObject * obj) -{ - return ((THPWrapper*)obj)->data; -} - -static PyObject * THPWrapper_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ - PyObject* self = type->tp_alloc(type, 0); - THPWrapper* wrapper = (THPWrapper*) self; - wrapper->data = nullptr; - wrapper->destructor = nullptr; - return self; -} - -static void THPWrapper_dealloc(THPWrapper* self) -{ - self->destructor(self->data); - Py_TYPE(self)->tp_free((PyObject*)self); -} - -PyTypeObject THPWrapperType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._PtrWrapper", /* tp_name */ - sizeof(THPWrapper), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THPWrapper_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPWrapper_pynew, /* tp_new */ -}; - -bool THPWrapper_init(PyObject *module) -{ - THPWrapperClass = (PyObject*)&THPWrapperType; - if (PyType_Ready(&THPWrapperType) < 0) - return false; - 
Py_INCREF(&THPWrapperType); - return true; -} diff --git a/torch/csrc/PtrWrapper.h b/torch/csrc/PtrWrapper.h deleted file mode 100644 index 985193c74c9b..000000000000 --- a/torch/csrc/PtrWrapper.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef THP_PTR_WRAPPER_H -#define THP_PTR_WRAPPER_H - -#include - -/** - * Python wrapper around arbitrary opaque C++ class - */ - -bool THPWrapper_init(PyObject *module); - -PyObject * THPWrapper_New(void *data, void (*destructor)(void*)); -void * THPWrapper_get(PyObject * obj); -bool THPWrapper_check(PyObject * obj); - -#endif diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h index edf4621765f8..26f6c06b3d20 100644 --- a/torch/csrc/THP.h +++ b/torch/csrc/THP.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 0121fef95155..6558295d58cb 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -160,10 +161,21 @@ std::tuple _euclidean_dist_backward(const Tensor & grad, const T x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)}; } -Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional & p_, const Tensor & norm) { +Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional & p_, const Tensor& norm) { + return norm_backward(grad, self, p_, norm, {}, true); +} + +Tensor norm_backward(Tensor grad, const Tensor& self, const optional & p_, Tensor norm, IntArrayRef dim, bool keepdim) { + size_t ndim = self.sizes().size(); double p = p_.value_or(2.0).toDouble(); Tensor self_scaled; Tensor scale_v; + + if (!keepdim && self.dim() != 0) { + grad = unsqueeze_multiple(grad, dim, ndim); + norm = unsqueeze_multiple(norm, dim, ndim); + } + if (p == 0.0) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else if (p == 1.0) { @@ -172,8 +184,13 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional & p_, Tensor norm, IntArrayRef dim, bool keepdim) { - IntArrayRef sizes = self.sizes(); - if (!keepdim && self.dim() != 0) { - if (dim.size()==1) { - grad = grad.unsqueeze(dim[0]); - norm = norm.unsqueeze(dim[0]); - } else { - auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, sizes.size()); - for (size_t i = 0; i < sizes.size(); i++){ - if (dims_to_unsqueeze[i]) { - grad = grad.unsqueeze(i); - norm = norm.unsqueeze(i); - } - } - } - } - return norm_backward(grad, self, p_, norm); -} - -Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) { - auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble(); - if (exponent == 0.0) { +Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) { + if (exponent.equal(0.0)) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { - auto out = grad * (exponent * self.pow(exponent - 1)).conj(); + auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); }; + Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble()); return handle_r_to_c(self, out); } } @@ -243,9 +241,8 @@ Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& expo } Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { - auto base_ = base.isComplex() ? 
base.toComplexDouble() : base.toDouble(); - auto grad_lambda = [](auto a, auto b) { return (a * std::log(b)).conj(); }; - if (base_ == 0.0) { + auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); }; + if (base.equal(0.0)) { auto cond = [](auto exp) { if (exp.is_complex()) { return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); @@ -255,10 +252,10 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp }; auto out = grad * at::where(cond(exponent), at::zeros({}, grad.options()), - grad_lambda(result, base_)); + grad_lambda(result, base)); return handle_r_to_c(exponent, out); } else { - auto out = grad * grad_lambda(result, base_); + auto out = grad * grad_lambda(result, base); return handle_r_to_c(exponent, out); } } @@ -2215,15 +2212,17 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) return nonsingular_case_backward(grad, self, det); } } else { - auto nonzero_det_indices = at::where(det); + auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det)); + c10::optional first_nonzero_det_index = nonzero_det_indices[0]; - if (nonzero_det_indices[0].size(0) == det.numel()) { // all determinants are nonzero (non-singular) + if (first_nonzero_det_index->size(0) == det.numel()) { // all determinants are nonzero (non-singular) return nonsingular_case_backward(grad, self, det); } - auto zero_det_indices = at::where(det == 0); + auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0)); + c10::optional first_zero_det_index = zero_det_indices[0]; - if (zero_det_indices[0].size(0) == det.numel()) { // all determinants are zero (singular) + if (first_zero_det_index->size(0) == det.numel()) { // all determinants are zero (singular) return singular_case_backward(grad, self, det); } @@ -2265,15 +2264,17 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo return singular_case_backward(grad, self); } } else { - auto finite_logdet_indices = at::where(logdet != -INFINITY); + auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); + c10::optional first_finite_logdet_index = finite_logdet_indices[0]; - if (finite_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are finite (non-singular) + if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad, self); } - auto neginf_logdet_indices = at::where(logdet == -INFINITY); + auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); + c10::optional first_neginf_logdet_index = neginf_logdet_indices[0]; - if (neginf_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are -inf (singular) + if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad, self); } @@ -2317,15 +2318,17 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, return nonsingular_case_backward(grad_logabsdet, self); } } else { - auto nonzero_signdet_indices = at::where(signdet); + auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet)); + c10::optional first_nonzero_signdet_index = nonzero_signdet_indices[0]; - if (nonzero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular) + if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are 
finite (non-singular) return nonsingular_case_backward(grad_logabsdet, self); } - auto zero_signdet_indices = at::where(signdet == 0); + auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0)); + c10::optional first_zero_signdet_index = zero_signdet_indices[0]; - if (zero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) + if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad_logabsdet, self); } @@ -2877,8 +2880,8 @@ Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indic return gg_weight.view(size); } -Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) { - return at::_index_put_impl_(zeros_like_self, indices, grad, true, true); +Tensor index_backward(Tensor zeros_like_self, const torch::List>& indices, const Tensor& grad) { + return at::_index_put_impl_(zeros_like_self, indices, grad, true, true); } Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) { diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 3814e8078b23..30736e13f58a 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -124,7 +124,7 @@ at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape); at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx); -at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad); +at::Tensor index_backward(at::Tensor zeros_like_self, const torch::List>& indices, const at::Tensor& grad); at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); Tensor svd_backward(const std::vector &grads, const Tensor& self, diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 0663d7f46fa8..d1f15fff3669 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -66,10 +66,6 @@ Tensor unpack_opt(const Tensor & t, const char * name, int pos) { return unpack(t, name, pos); } -c10::optional unpack_opt(const c10::optional & t, const char * name, int pos) { - return t; -} - std::vector unpack(at::TensorList tl, const char *name, int pos) { std::vector ret(tl.size()); for (size_t i = 0; i < tl.size(); ++i) { @@ -94,7 +90,7 @@ void _backward( // instead of us having to unwrap it to Tensor _gradient here. Tensor _gradient = gradient.has_value() ? 
*gradient : Tensor(); std::vector input_vars(inputs.begin(), inputs.end()); - torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph, input_vars); + torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars); } void set_data(Tensor & self, const Tensor & new_data) { @@ -230,7 +226,6 @@ Tensor _fw_primal(const Tensor & self, int64_t level) { // We don't have an outplace copy, so this can't be generated automatically Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { - jit::Value* output = nullptr; // TODO: once copy is exposed in Declarations.yaml we may be able to bind // it automatically auto& self_ = unpack(self, "self", 0); @@ -282,7 +277,7 @@ Tensor& resize_( } { at::AutoNonVariableTypeMode non_var_type_mode(true); - self_.resize_(size, std::move(optional_memory_format)); + self_.resize_(size, optional_memory_format); } if (self.fw_grad(/* level */ 0).defined()) { @@ -303,7 +298,7 @@ Tensor& resize_as_( } { at::AutoNonVariableTypeMode non_var_type_mode(true); - at::resize_as_(self_, the_template_, std::move(optional_memory_format)); + at::resize_as_(self_, the_template_, optional_memory_format); } // Handle fw grad diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index af02de68fc27..509a12e01140 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -266,12 +266,31 @@ inline void check_no_requires_grad(TensorList tensors, const char* name) { } } +inline void check_no_requires_grad(const c10::List>& tensors, const char* name) { + for (c10::optional tensor : tensors) { + if (tensor.has_value()) { + check_no_requires_grad(*tensor, name); + } + } +} + // Assumed that saved tensor lists are never inplace outputs inline std::vector make_saved_variable_list(TensorList tensors) { return fmap(tensors, [](const Tensor& tensor) -> SavedVariable { return SavedVariable{tensor, false /* is output */}; }); } +// Assumed that saved tensor lists are never inplace outputs +inline std::vector make_saved_variable_list(const c10::List>& tensors) { + return fmap(tensors, [](const c10::optional& tensor) -> SavedVariable { + if (tensor.has_value()) { + return SavedVariable{*tensor, false /* is output */}; + } else { + return SavedVariable{Tensor(), false /* is output */}; + } + }); +} + inline std::vector> to_args_sizes(TensorList tensors) { std::vector> args_sizes(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index ca419522dff8..d86073a7af79 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -52,7 +52,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .value("CUDA", ActivityType::CUDA); py::class_(m, "ProfilerConfig") - .def(py::init()); + .def(py::init()); py::class_(m, "ProfilerEvent") .def("kind", &LegacyEvent::kindStr) diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 3b1d254e985b..85272677a06b 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -226,8 +226,10 @@ void ProfilerThreadLocalState::pushRange( evt.setSequenceNr(fn.seqNr()); evt.setFwdThreadId(fn.forwardThreadId()); evt.setScope((uint8_t)fn.scope()); - evt.setExtraArgs(saveExtraArgs(fn)); - evt.setFlops(computeFlops(std::string(fn.name().str()), evt.extraArgs())); + if (config_.with_flops) { + evt.setExtraArgs(saveExtraArgs(fn)); + 
evt.setFlops(computeFlops(std::string(fn.name().str()), evt.extraArgs())); + } #ifndef C10_MOBILE // backward nodes source range corresponds to the forward node // TODO: consider using C++ stack trace diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 3e07c8cb541b..23169cd33450 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -387,16 +387,19 @@ struct TORCH_API ProfilerConfig { ProfilerState state, bool report_input_shapes = false, bool profile_memory = false, - bool with_stack = false) + bool with_stack = false, + bool with_flops = false) : state(state), report_input_shapes(report_input_shapes), profile_memory(profile_memory), - with_stack(with_stack) {} + with_stack(with_stack), + with_flops(with_flops) {} ~ProfilerConfig() = default; ProfilerState state; bool report_input_shapes; bool profile_memory; bool with_stack; + bool with_flops; // Returns IValues corresponding to ProfilerConfig struct, to be used for // serialization. diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index eee29481bea5..a9c7d709466e 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/backends/backend_detail.h b/torch/csrc/jit/backends/backend_detail.h index 2d19f2ed8950..00f0f2f9eb44 100644 --- a/torch/csrc/jit/backends/backend_detail.h +++ b/torch/csrc/jit/backends/backend_detail.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace torch { diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h new file mode 100644 index 000000000000..fa92ce22d6e4 --- /dev/null +++ b/torch/csrc/jit/cuda/cuda.h @@ -0,0 +1,179 @@ +#include +#include +#include +#include + +namespace torch { +namespace jit { + +class CUDAEvent; +// This class is a wrapper around c10::cuda::CUDAStream. +// It is needed because TorchBind does not support all of the argument types +// for c10::cuda::CUDAStream. For more details, please refer to +// c10/cuda/CUDAStream.h. +class CUDAStream final : public CustomClassHolder { + public: + CUDAStream(int64_t device = -1, int64_t priority = 0) { + constexpr int64_t PRIORITY_INDEX = 0; + stream_ = std::make_unique( + c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device)); + } + + CUDAStream(c10::cuda::CUDAStream s) { + stream_ = std::make_unique(s); + } + + bool query() { + return stream_->query(); + } + + c10::intrusive_ptr recordEvent( + c10::intrusive_ptr event); + + void synchronize() { + stream_->synchronize(); + } + + void waitEvent(c10::intrusive_ptr event); + + void waitStream(c10::intrusive_ptr stream); + + /// Get the CUDA device index that this stream is associated with. + int64_t device_index() const { + return stream_->device_index(); + } + + /// Get the full Device that this stream is associated with. The Device + /// is guaranteed to be a CUDA device. + c10::Device device() const { + return stream_->device(); + } + + /// Return the stream ID corresponding to this particular stream. + int64_t id() const { + return stream_->id(); + } + + /// Pack a CUDAStream to uint64_t representation. + /// The CUDAStream can be unpacked using unpack(). The format of + /// the uint64_t is unspecified and may be changed. + int64_t pack() const { + return stream_->pack(); + } + + private: + std::unique_ptr stream_; + friend class CUDAEvent; +}; + +// This class is a wrapper around at::cuda::CUDAStream. 
+// It is needed because TorchBind does not support all of the argument types +// for at::cuda::CUDAEvent. For more details, please refer to +// aten/src/ATen/cuda/CUDAEvent.h. +class CUDAEvent final : public CustomClassHolder { + public: + CUDAEvent( + bool enable_timing = false, + bool blocking = false, + bool interprocess = false) { + int flags = cudaEventDisableTiming; + if (enable_timing) { + flags = cudaEventDefault; + } + if (blocking) { + flags |= cudaEventBlockingSync; + } + if (interprocess) { + TORCH_CHECK(!enable_timing); + flags |= cudaEventInterprocess; + } + + event_ = std::make_unique(flags); + } + + double elapsedTime(c10::intrusive_ptr end) { + return event_->elapsed_time(*end->event_); + } + + std::string ipcHandle() { + cudaIpcEventHandle_t handle; + event_->ipc_handle(&handle); + std::string str_handle((const char*)&handle, sizeof(handle)); + return str_handle; + } + + bool query() { + return event_->query(); + } + + void record(c10::intrusive_ptr stream); + + void synchronize() { + event_->synchronize(); + } + void wait(c10::intrusive_ptr stream); + + private: + void recordInternal(CUDAStream* stream); + std::unique_ptr event_; + + friend class CUDAStream; +}; + +c10::intrusive_ptr CUDAStream::recordEvent( + c10::intrusive_ptr event) { + if (!event) { + event = c10::make_intrusive(); + } + + event->recordInternal(this); + return event; +} + +void CUDAStream::waitEvent(c10::intrusive_ptr event) { + event->event_->block(*stream_); +} + +void CUDAStream::waitStream(c10::intrusive_ptr stream) { + auto ev = c10::make_intrusive(); + stream->recordEvent(ev); + waitEvent(ev); +} + +void CUDAEvent::record(c10::intrusive_ptr stream) { + event_->record(*stream->stream_); +} + +void CUDAEvent::recordInternal(CUDAStream* stream) { + event_->record(*stream->stream_); +} + +void CUDAEvent::wait(c10::intrusive_ptr stream) { + event_->block(*stream->stream_); +} + +TORCH_LIBRARY(cuda, m) { + auto stream_class = m.class_("Stream").def( + torch::init()); + auto event_class = m.class_("Event").def( + torch::init()); + + stream_class.def("query", &CUDAStream::query) + .def("record_event", &CUDAStream::recordEvent) + .def("synchronize", &CUDAStream::synchronize) + .def("wait_event", &CUDAStream::waitEvent) + .def("wait_stream", &CUDAStream::waitStream) + .def("device_index", &CUDAStream::device_index) + .def("device", &CUDAStream::device) + .def("pack", &CUDAStream::pack) + .def("id", &CUDAStream::id); + + event_class.def("elapsed_time", &CUDAEvent::elapsedTime) + .def("query", &CUDAEvent::query) + .def("record", &CUDAEvent::record) + .def("synchronize", &CUDAEvent::synchronize) + .def("wait", &CUDAEvent::wait); +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index 8b1aa58b5aff..f4c1fa2c920d 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -211,6 +211,13 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const { } } + // Check if the type is a custom class. This is done by checking + // if type_name starts with "torch.classes." + if (type_name.find("torch.classes.") == 0) { + auto custom_class_type = getCustomClass("__torch__." 
+ type_name); + return custom_class_type; + } + throw ErrorReport(expr) << "Unknown type name '" << type_name << "'"; } else if (auto name = parseBaseTypeName(expr)) { auto itr = string_to_type_lut().find(*name); diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 72ccd77f2220..1bab391bd393 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -103,6 +103,9 @@ void TracingState::delValue(const IValue& var) { Value* getValueTrace(const IValue& var) { return getTracingState()->getValue(var); } +Value* getOptTensorValueTrace(const c10::optional& var) { + return getValueTrace(IValue(var)); +} Value* TracingState::getValue(const IValue& var) { // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] // arguments @@ -686,6 +689,16 @@ void addInputs( } n->addInput(list_node->output()); } +TORCH_API void addInputs( + Node* n, + const char* name, + const List>& value) { + Graph* g = n->owningGraph(); + Node* list_node = nullptr; + list_node = g->insertNode(g->createList( + OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace))); + n->addInput(list_node->output()); +} void addInputs( Node* n, diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index 61d79cb3efd2..f5cbd821bda4 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -255,6 +255,10 @@ TORCH_API void addInputs( const char* name, ArrayRef value, bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + const List>& value); TORCH_API void addInputs( Node* n, const char* name, diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 0b3e4a4a7b41..1ca0f48f9e17 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -572,7 +572,8 @@ void AliasDb::analyzeImpl(Node* node) { !aliasAnalysisHasSpecialCaseFor(node->kind()), "Special cases should be handled already if we're here."); - if (node->kind().is_aten() || node->kind().is_prim()) { + if (node->kind().is_aten() || node->kind().is_prim() || + node->kind().is_cuda()) { // TODO There is nothing in the system that relies on aten:: and prim:: // ops using AliasAnalysisKind::FROM_SCHEMA or // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 65b410d82069..eb75928e5952 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1079,6 +1079,11 @@ bool Node::hasSideEffects() const { case prim::rpc_sync: // It represents RPC message sent. case prim::rpc_remote: // It represents RPC message sent. case aten::wait: // It can represent RPC message received. 
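+    // The cuda:: ops below read or mutate global CUDA state (the current
+    // device and the current stream), so they must be treated as
+    // side-effecting to keep the compiler from reordering or eliminating
+    // them. They are compiled out for HIP builds, where these symbols are
+    // unavailable.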
+#ifndef __HIP_PLATFORM_HCC__ + case cuda::set_stream: + case cuda::_set_device: + case cuda::_current_device: +#endif case prim::Enter: case prim::Exit: return true; @@ -1094,7 +1099,7 @@ bool Node::hasSideEffects() const { return false; } - if (kind_.is_prim() || kind_.is_aten()) { + if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) { // TODO There is nothing in the system that relies on aten:: and prim:: // ops using AliasAnalysisKind::FROM_SCHEMA, // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 21f172f01465..02867b8639cd 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -72,6 +72,11 @@ using namespace ::c10::attr; namespace aten { using namespace ::c10::aten; } +namespace cuda { +#ifndef __HIP_PLATFORM_HCC__ +using namespace ::c10::cuda; +#endif +} // namespace cuda struct Function; struct MatchedSchema; diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index 8b7da739df9a..2be75c61b6b5 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -1,5 +1,6 @@ #pragma once //#include +#include #include #include diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 933d3bb1a867..056e23d06f02 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -217,6 +217,32 @@ std::shared_ptr PythonModuleValue::attr( return toSugaredValue(member, m, loc, /*is_constant=*/true); } +#ifndef __HIP_PLATFORM_HCC__ +std::shared_ptr CUDAPythonModuleValue::attr( + const SourceRange& loc, + Function& m, + const std::string& field) { + // List of all the cuda operators which are supported in JIT + const std::unordered_set cuda_ops = {"current_stream", + "default_stream", + "_current_device", + "_set_device", + "device_index", + "device_count", + "set_stream"}; + + if (cuda_ops.find(field) != cuda_ops.end()) { + return std::make_shared(Symbol::cuda(field), c10::nullopt); + } + + py::object member = getattr(loc, field); + // note: is_constant = true because we consider that global properties + // on modules like math.pi or torch.float to be constants + // even though it is possible, though rare, for someone to mutate them + return toSugaredValue(member, m, loc, /*is_constant=*/true); +} +#endif + Value* ModuleValue::asValue(const SourceRange& loc, Function& m) { return self_; } @@ -938,6 +964,12 @@ std::shared_ptr toSugaredValue( if (auto callee = as_function(obj)) { return std::make_shared(callee->function_); } else if (py::isinstance(obj)) { +#ifndef USE_ROCM + std::string obj_name = py::cast(py::getattr(obj, "__name__")); + if (obj_name.compare("torch.cuda") == 0) { + return std::make_shared(obj); + } +#endif return std::make_shared(obj); } else if ( obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() || diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index b5d8f4490b3e..1edbc6c15cad 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -91,6 +91,20 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { const std::string& field) override; }; +// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with +// torch.cuda.* are resolved using CUDAPythonModuleValue. 
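For orientation, here is a minimal sketch of what this desugaring enables on the TorchScript side. It is illustrative only and not part of this patch; it assumes a CUDA-enabled build, and the function name is made up.

```python
import torch

# Illustrative only: inside a scripted function, the torch.cuda attributes in
# cuda_ops are resolved to the cuda:: JIT operators registered in
# torch/csrc/jit/runtime/register_cuda_ops.cpp instead of calling into Python.
@torch.jit.script
def switch_to_default_stream() -> int:
    dev = torch.cuda._current_device()   # cuda::_current_device
    s = torch.cuda.default_stream(dev)   # cuda::default_stream -> torch.classes.cuda.Stream
    torch.cuda.set_stream(s)             # cuda::set_stream
    return torch.cuda.device_count()     # cuda::device_count
```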
+#ifndef __HIP_PLATFORM_HCC__ +struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue { + explicit CUDAPythonModuleValue(py::object mod) + : PythonValue(std::move(mod)) {} + + std::shared_ptr attr( + const SourceRange& loc, + Function& m, + const std::string& field) override; +}; +#endif + // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { ConstantParameterList(Value* the_list) : the_list_(the_list) {} diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 120a3ffb7507..a4bb209cd17e 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -5,6 +5,7 @@ #include #include +#include #include #include diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp new file mode 100644 index 000000000000..5cf31d626dd0 --- /dev/null +++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp @@ -0,0 +1,87 @@ +// This file registers special JIT operators used to implement the PyTorch CUDA +// API in TorchScript. +#ifndef __HIP_PLATFORM_HCC__ +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace { + +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; +} + +RegisterOperators const reg({ + Operator( + "cuda::current_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream", + [](Stack* stack) { + auto idx = uint16_t(pop(stack).toInt()); + auto s = c10::cuda::getCurrentCUDAStream(idx); + auto st = make_custom_class(s); + push(stack, IValue(st)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::default_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream", + [](Stack* stack) { + auto idx = uint16_t(pop(stack).toInt()); + auto s = c10::cuda::getDefaultCUDAStream(idx); + auto st = make_custom_class(s); + push(stack, IValue(st)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::_current_device() -> int", + [](Stack* stack) { + auto v = c10::cuda::current_device(); + push(stack, static_cast(v)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::_set_device(int64_t val) -> ()", + [](Stack* stack) { + int64_t idx = -1; + pop(stack, idx); + c10::cuda::set_device(static_cast(idx)); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::device_index(Device device) -> int", + [](Stack* stack) { + auto device = pop(stack); + auto idx = device.toDevice().index(); + push(stack, idx); + }, + aliasAnalysisFromSchema()), + Operator( + "cuda::device_count() -> int", + [](Stack* stack) { push(stack, at::cuda::device_count()); }, + aliasAnalysisFromSchema()), + Operator( + "cuda::set_stream(__torch__.torch.classes.cuda.Stream stream) -> ()", + [](Stack* stack) { + auto v = pop(stack); + auto s = v.toCustomClass(); + // To set the current CUDA stream using + // c10::cuda::setCurrentCUDAStream, the jit::CUDAStream object needs + // to be converted to c10::cuda::CUDAStream. Since the latter cannot + // be returned from a class registered via TorchBind, this can only be + // achieved by packing the c10::cuda::CUDAStream instance contained + // inside the jit::CUDAStream object to a uint64_t representation, and + // unpacking it inside this operator. The unpacked stream is then used + // to set the current CUDA stream. 
+ auto packed = s->pack(); + auto unpacked = c10::cuda::CUDAStream::unpack(packed); + c10::cuda::setCurrentCUDAStream(unpacked); + }, + aliasAnalysisFromSchema()), +}); +} // namespace +} // namespace jit +} // namespace torch +#endif diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index f23b09dc0e74..fe75ec52046e 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -908,7 +908,7 @@ RegisterOperators reg( TORCH_SELECTIVE_SCHEMA( "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"), [](Stack* stack) { - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::index(self, indices); push(stack, std::move(result)); @@ -921,7 +921,7 @@ RegisterOperators reg( auto unsafe = pop(stack).toBool(); auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::_index_put_impl_(self, indices, values, accumulate, unsafe); @@ -934,7 +934,7 @@ RegisterOperators reg( [](Stack* stack) { auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::index_put_(self, indices, values, accumulate); push(stack, std::move(result)); @@ -946,7 +946,7 @@ RegisterOperators reg( [](Stack* stack) { auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); - auto indices = pop(stack).toTensorVector(); + auto indices = pop(stack).to>>(); auto self = pop(stack).toTensor(); auto result = at::index_put_(self, indices, values, accumulate); push(stack, std::move(result)); diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 75688cfa8880..bbaaa6683bbd 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -105,6 +105,23 @@ void ClipRangesGatherSigridHash(std::shared_ptr& graph) { fuse.runOnGraph(graph); } +void ClipRangesGatherRangesSigridHash( + std::shared_ptr& graph) { + std::string pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %y0 : Tensor = fb::clip_ranges(%b, %c) + %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0) + %y3 : Tensor = fb::sigrid_hash(%y1, %d, %e, %f) + return (%y3, %y2))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d, %e, %f): + %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_v3(%b, %a, %c, %d, %e, %f) + return (%out, %off))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + void FuseInferenceOpsForSparseNN(std::shared_ptr& graph) { #ifdef FBCODE_CAFFE2 ConcatAddMulReplaceNaNClip(graph); @@ -112,6 +129,7 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr& graph) { ConcatBatchMatMulBatchGather(graph); ClipRangesGatherRangesLengthsToOffsets(graph); ClipRangesGatherSigridHash(graph); + ClipRangesGatherRangesSigridHash(graph); #endif } diff --git a/torch/csrc/jit/runtime/vararg_functions.h b/torch/csrc/jit/runtime/vararg_functions.h index 36bef721d626..d6eba7f5d191 100644 --- a/torch/csrc/jit/runtime/vararg_functions.h +++ b/torch/csrc/jit/runtime/vararg_functions.h @@ -2,6 +2,7 @@ #include #include #include +#include #include namespace torch { diff --git 
a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index e60a0bd704bf..186af3ca822f 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -834,8 +834,12 @@ class SimpleIREvaluatorImpl : public IRVisitor { return std::erfc(v); case kSqrt: return std::sqrt(v); - case kRsqrt: - return 1.0f / std::sqrt(v); + case kRsqrt: { + auto rsqrt = [](TInput v) __ubsan_ignore_float_divide_by_zero__ { + return 1.0f / std::sqrt(v); + }; + return rsqrt(v); + } case kCeil: return std::ceil(v); case kFloor: diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 999186d4c4ed..0145014ee8f5 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1282,8 +1282,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::rsqrt: { - return computeOneOperand( - "aten_rsqrt", v, [](const ExprHandle& a) { return rsqrt(a); }); + return computeOneOperand("aten_rsqrt", v, [](const ExprHandle& a) { + return rsqrt(promoteIntegerToDefaultType(a)); + }); } break; case aten::abs: { diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index c7fdf844945e..ee3a0bc71f2f 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -24,6 +24,7 @@ static std::unordered_map type_map = { {"double", ParameterType::DOUBLE}, {"complex", ParameterType::COMPLEX}, {"TensorList", ParameterType::TENSOR_LIST}, + {"c10::List>", ParameterType::TENSOR_LIST}, {"IntArrayRef", ParameterType::INT_LIST}, {"ArrayRef", ParameterType::FLOAT_LIST}, {"Generator", ParameterType::GENERATOR}, @@ -390,7 +391,7 @@ bool is_float_or_complex_list(PyObject* obj) { } auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); - if (size > 0) { + if (size > 0) { PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) { return false; diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index ccf3ba6b42c4..9fa490139cbd 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -160,6 +160,7 @@ struct PythonArgs { inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar); inline std::vector scalarlist(int i); inline std::vector tensorlist(int i); + inline torch::List> list_of_optional_tensors(int i); template inline std::array tensorlist_n(int i); inline std::vector intlist(int i); @@ -327,6 +328,22 @@ inline std::vector PythonArgs::tensorlist(int i) { return res; } +inline torch::List> PythonArgs::list_of_optional_tensors(int i) { + if (!args[i]) return torch::List>(); + auto tuple = six::isTuple(args[i]); + THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + torch::List> res; + res.reserve(size); + for (int idx = 0; idx < size; idx++) { + PyObject* obj = tuple ? 
PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without checking + // if this is a tensor first + res.push_back(reinterpret_cast(obj)->cdata); + } + return res; +} + template inline std::array PythonArgs::tensorlist_n(int i) { auto res = std::array(); diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c0da643e8554..b7e76478b9d7 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -194,14 +194,24 @@ def _store_based_barrier(rank, store, timeout): # legacy stores, we can use 'get' here instead. worker_count = store.add(store_key, 0) start = time.time() + log_time = time.time() while worker_count != world_size: time.sleep(0.01) worker_count = store.add(store_key, 0) + + # Print status periodically to keep track. + if timedelta(seconds=(time.time() - log_time)) > timedelta(seconds=10): + logging.info( + "Waiting in store based barrier to initialize process group for " + "rank: {}, key: {} (world_size={}, worker_count={}, timeout={})".format( + rank, store_key, world_size, worker_count, timeout)) + log_time = time.time() + if timedelta(seconds=(time.time() - start)) > timeout: raise RuntimeError( "Timed out initializing process group in store based barrier on " - "rank: {}, for key: {} (world_size={}, worker_count={})".format( - rank, store_key, world_size, worker_count)) + "rank: {}, for key: {} (world_size={}, worker_count={}, timeout={})".format( + rank, store_key, world_size, worker_count, timeout)) def _rank_not_in_group(group: ProcessGroup): """ diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 6cd7b168ec6a..d8de89bfc937 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -116,6 +116,7 @@ def __enter__(self): profiler_kind, self.record_shapes, self.profile_memory, + False, False) _enable_server_process_global_profiler(profiler_config) return self diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index f2b0c5c53a99..cfd327165899 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -44,6 +44,7 @@ from torch.jit._serialization import save, load from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph +from torch.jit.cuda import stream from torch.jit._freeze import freeze # For backwards compatibility diff --git a/torch/jit/cuda.py b/torch/jit/cuda.py new file mode 100644 index 000000000000..16805301600b --- /dev/null +++ b/torch/jit/cuda.py @@ -0,0 +1,182 @@ +# mypy: ignore-errors + +r""" +This package adds support for JIT compilation for CUDA Streams and events, +This is similar to API's available in the eager mode +:ref:`cuda-semantics` has more details about working with CUDA. +""" + +import torch +from typing import Optional, Any +from torch import device as _device + +def get_current_device_index() -> int: + r"""Checks if there are CUDA devices available and + returns the device index of the current default CUDA device. + Returns -1 in case there are no CUDA devices available. 
+
+    Arguments: ``None``
+    """
+    if torch.cuda.device_count() > 0:
+        return torch.cuda._current_device()
+    return -1
+
+def get_device_index(device: Optional[_device] = None, optional: bool = False, allow_cpu: bool = False) -> int:
+    r"""Gets the device index from :attr:`device`, which can be a torch.device
+    object, a Python integer, or ``None``.
+
+    If :attr:`device` is a torch.device object, returns the device index if it
+    is a CUDA device. Note that for a CUDA device without a specified index,
+    this will return the current default CUDA device if :attr:`optional` is ``True``.
+    If :attr:`allow_cpu` is ``True``, CPU devices will be accepted and ``-1`` will be
+    returned in this case.
+
+    If :attr:`device` is a Python integer, it is returned as is.
+
+    If :attr:`device` is ``None``, this will return the current default CUDA
+    device if :attr:`optional` is ``True``.
+    """
+    if device is None:
+        if optional:
+            return get_current_device_index()
+        else:
+            raise ValueError('Expected a torch.device with a specified index '
+                             f'or an integer, but got: {device}')
+    device_index = -1
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    if isinstance(device, torch.device):
+        if not allow_cpu and device.type == 'cpu':
+            raise ValueError(f'Expected a non cpu device, but got: {device}')
+        device_index = -1 if device.type == 'cpu' else torch.cuda.device_index(device)
+
+    if isinstance(device, int):
+        device_index = device
+
+    return device_index
+
+class device(object):
+    r"""Context-manager that changes the selected device.
+    This is similar to device (torch.device or int), but has been
+    introduced for JIT compatibility.
+    Arguments:
+        device (torch.device or int): device index to select. It's a no-op if
+            this argument is a negative integer or ``None``.
+    """
+    def __init__(self, device: Optional[_device]):
+        self.idx = -1
+        self.prev_idx = -1
+        self.device = device
+
+    def __enter__(self):
+        self.idx = get_device_index(self.device, optional=True)
+
+        if self.idx == -1:
+            return
+        self.prev_idx = torch.cuda._current_device()
+
+        if self.prev_idx != self.idx:
+            torch.cuda._set_device(self.idx)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        if self.prev_idx != self.idx:
+            torch.cuda._set_device(self.prev_idx)
+
+class StreamContext(object):
+    r"""Context-manager that selects a given stream.
+    All CUDA kernels queued within its context will be enqueued on a selected
+    stream.
+    Arguments:
+        StreamContext (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    .. note:: Streams are per-device. If the selected stream is not on the
+        current device, this function will also change the current device to
+        match the stream.
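+
+    Example (sketch; assumes a CUDA-enabled build and a scripted caller)::
+
+        s = torch.jit.cuda.Stream(0, 0)
+        with torch.jit.cuda.stream(s):
+            y = x + 1  # kernels here are enqueued on ``s``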
+ """ + cur_stream : Optional['torch.classes.cuda.Stream'] + + def __init__(self, stream: Optional['torch.classes.cuda.Stream']): + self.idx = -1 + self.stream = stream + # Initialize the below streams to default stream on the current device + self.device_index = get_current_device_index() + self.src_prev_stream = torch.cuda.default_stream(self.device_index) + self.dst_prev_stream = torch.cuda.default_stream(self.device_index) + + def __enter__(self): + self.idx = get_device_index(device=None, optional=True) + # If there is no CUDA device available, return + if self.idx == -1: + return + + # Local cur_stream variable for type refinement + cur_stream = self.stream + # Return if stream is None + if cur_stream is None: + return + self.src_prev_stream = torch.cuda.current_stream(self.idx) + # If the stream is not on the current device, then change the device + # and set the current stream on the device + if self.src_prev_stream.device_index() != cur_stream.device_index(): + with device(cur_stream.device()): + self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device_index()) + torch.cuda._set_device(cur_stream.device_index()) + torch.cuda.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + # Local cur_stream variable for type refinement + cur_stream = self.stream + # If stream is None or no CUDA device available, return + if cur_stream is None or self.idx == -1: + return + # If the stream was not on the current device, restore the previous stream on + # the destination device and also reset the current device to the previous device. + # Set the current stream on the device to the src_prev_stream + if self.src_prev_stream.device_index() != cur_stream.device_index(): + torch.cuda.set_stream(self.dst_prev_stream) + torch.cuda._set_device(self.idx) + torch.cuda.set_stream(self.src_prev_stream) + +def stream(stream: Optional['torch.classes.cuda.Stream']) -> StreamContext: + r"""Wrapper around the Context-manager that selects a given stream. + All CUDA kernels queued within its context will be enqueued on a selected + stream. + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + """ + return StreamContext(stream) + +def Stream(device: int = -1, priority: int = 0) -> 'torch.classes.cuda.Stream': + r"""Wrapper around a CUDA stream. + A CUDA stream is a linear sequence of execution that belongs to a specific + device, independent from other streams. See :ref:`cuda-semantics` for + details. + Arguments: + device(int, optional): a device on which to allocate + the stream. If :attr:`device` is ``None`` (default) or a negative + integer, this will use the current device. + priority(int, optional): priority of the stream. Can be either + -1 (high priority) or 0 (low priority). By default, streams have + priority 0. + .. note:: Although CUDA versions >= 11 support more than two levels of + priorities, in PyTorch, we only support two levels of priorities. + """ + return torch.classes.cuda.Stream(device, priority) + +def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> 'torch.classes.cuda.Event': + r"""Wrapper around a CUDA event. + CUDA events are synchronization markers that can be used to monitor the + device's progress, to accurately measure timing, and to synchronize CUDA + streams. 
+ Arguments: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``) + interprocess (bool): if ``True``, the event can be shared between processes + (default: ``False``) + .. _CUDA Event Documentation: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html + """ + return torch.classes.cuda.Event(enable_timing, blocking, interprocess) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 8a16c8c27808..0c5258615bfd 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -365,11 +365,11 @@ class SiLU(Module): \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} .. note:: - See `Gaussian Error Linear Units (GELUs) `_ - where the SiLU (Sigmoid Linear Unit) was originally coined, and see - `Sigmoid-Weighted Linear Units for Neural Network Function Approximation - in Reinforcement Learning `_ and `Swish: - a Self-Gated Activation Function `_ + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ where the SiLU was experimented with later. Shape: @@ -937,8 +937,7 @@ def forward(self, query, key, value, key_padding_mask=None, attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. - Shape: - - Inputs: + Shapes for inputs: - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is @@ -949,15 +948,17 @@ def forward(self, query, key, value, key_padding_mask=None, If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the + source sequence length. + + If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence + length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend + the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - - Outputs: + Shapes for outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. 
- attn_output_weights: :math:`(N, L, S)` where N is the batch size, diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index c06b7a5534f6..dd491ba99620 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -2,7 +2,7 @@ from typing import Tuple, Union from torch import Tensor -from torch import Size +from torch.types import _size class Flatten(Module): @@ -53,8 +53,8 @@ class Unflatten(Module): be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be - a `tuple` of ints or `torch.Size` for `Tensor` input or a `NamedShape` (tuple of `(name, size)` tuples) - for `NamedTensor` input. + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. Shape: - Input: :math:`(N, *dims)` @@ -62,7 +62,7 @@ class Unflatten(Module): Args: dim (Union[int, str]): Dimension to be unflattened - unflattened_size (Union[torch.Size, NamedShape]): New shape of the unflattened dimension + unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension Examples: >>> input = torch.randn(2, 50) @@ -71,7 +71,7 @@ class Unflatten(Module): >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) - >>> output = m(output) + >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size @@ -79,15 +79,13 @@ class Unflatten(Module): >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) - >>> output = m(output) + >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) - >>> m = nn.Sequential( - >>> nn.Linear(50, 50), - >>> nn.Unflatten('features', (('C', 2), ('H', 50), ('W',50))) - >>> ) - >>> output = m(output) + >>> input = torch.randn(2, 50, names=('N', 'features')) + >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) + >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5]) """ @@ -95,9 +93,9 @@ class Unflatten(Module): __constants__ = ['dim', 'unflattened_size'] dim: Union[int, str] - unflattened_size: Union[Size, NamedShape] + unflattened_size: Union[_size, NamedShape] - def __init__(self, dim: Union[int, str], unflattened_size: Union[Size, NamedShape]) -> None: + def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None: super(Unflatten, self).__init__() if isinstance(dim, int): @@ -121,7 +119,7 @@ def _require_tuple_tuple(self, input): "but found type {}".format(type(input).__name__)) def _require_tuple_int(self, input): - if (isinstance(input, tuple)): + if (isinstance(input, (tuple, list))): for idx, elem in enumerate(input): if not isinstance(elem, int): raise TypeError("unflattened_size must be tuple of ints, " + diff --git a/torch/nn/quantizable/__init__.py b/torch/nn/quantizable/__init__.py new file mode 100644 index 000000000000..270dcebaa5f4 --- /dev/null +++ b/torch/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/torch/nn/quantizable/modules/__init__.py b/torch/nn/quantizable/modules/__init__.py new file mode 100644 index 000000000000..b3480b717a2d --- /dev/null +++ b/torch/nn/quantizable/modules/__init__.py @@ -0,0 +1,7 @@ +from .rnn import LSTM +from .rnn import LSTMCell + +__all__ = [ + 'LSTM', + 'LSTMCell', +] diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py new file mode 100644 index 
000000000000..cfe076fac16c --- /dev/null +++ b/torch/nn/quantizable/modules/rnn.py @@ -0,0 +1,403 @@ +import numbers +from typing import Optional, Tuple +import warnings + +import torch +from torch import Tensor + +""" +We will recreate all the RNN modules as we require the modules to be decomposed +into its building blocks to be able to observe. +""" + +class LSTMCell(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM) cell. + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` + + Examples:: + + >>> import torch.nn.quantizable as nnqa + >>> rnn = nnqa.LSTMCell(10, 20) + >>> input = torch.randn(3, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + hx, cx = rnn(input[i], (hx, cx)) + output.append(hx) + """ + _FLOAT_MODULE = torch.nn.LSTMCell + + def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True): + super().__init__() + self.input_size = input_dim + self.hidden_size = hidden_dim + self.bias = bias + + self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias) + self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias) + self.gates = torch.nn.quantized.FloatFunctional() + + self.fgate_cx = torch.nn.quantized.FloatFunctional() + self.igate_cgate = torch.nn.quantized.FloatFunctional() + self.fgate_cx_igate_cgate = torch.nn.quantized.FloatFunctional() + + self.ogate_cy = torch.nn.quantized.FloatFunctional() + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: + if hidden is None or hidden == (None, None): + hidden = self.initialize_hidden(x.shape[0], x.is_quantized) + hx, cx = hidden + + igates = self.igates(x) + hgates = self.hgates(hx) + gates = self.gates.add(igates, hgates) + + input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) + + input_gate = torch.sigmoid(input_gate) + forget_gate = torch.sigmoid(forget_gate) + cell_gate = torch.tanh(cell_gate) + out_gate = torch.sigmoid(out_gate) + + fgate_cx = self.fgate_cx.mul(forget_gate, cx) + igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) + fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) + cy = fgate_cx_igate_cgate + + tanh_cy = torch.tanh(cy) + hy = self.ogate_cy.mul(out_gate, tanh_cy) + return hy, cy + + def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]: + h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size)) + if is_quantized: + h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8) + c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8) + return h, c + + def _get_name(self): + return 'QuantizableLSTMCell' + + @classmethod + def from_params(cls, wi, wh, bi=None, bh=None): + """Uses the weights and biases to create a new LSTM cell. 
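+
+        This is used, for example, by :meth:`from_float` below to rebuild the
+        cell from the parameters of an existing float ``torch.nn.LSTMCell``.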
+ + Args: + wi, wh: Weights for the input and hidden layers + bi, bh: Biases for the input and hidden layers + """ + assert (bi is None) == (bh is None) # Either both None or both have values + input_size = wi.shape[1] + hidden_size = wh.shape[1] + cell = cls(input_dim=input_size, hidden_dim=hidden_size, + bias=(bi is not None)) + cell.igates.weight = torch.nn.Parameter(wi) + if bi is not None: + cell.igates.bias = torch.nn.Parameter(bi) + cell.hgates.weight = torch.nn.Parameter(wh) + if bh is not None: + cell.hgates.bias = torch.nn.Parameter(bh) + return cell + + @classmethod + def from_float(cls, other): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" + observed = cls.from_params(other.weight_ih, other.weight_hh, + other.bias_ih, other.bias_hh) + observed.qconfig = other.qconfig + observed.igates.qconfig = other.qconfig + observed.hgates.qconfig = other.qconfig + return observed + + +class _LSTMSingleLayer(torch.nn.Module): + r"""A single one-directional LSTM layer. + + The difference between a layer and a cell is that the layer can process a + sequence, while the cell only expects an instantaneous value. + """ + def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True): + super().__init__() + self.cell = LSTMCell(input_dim, hidden_dim, bias=bias) + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): + result = [] + for xx in x: + hidden = self.cell(xx, hidden) + result.append(hidden[0]) # type: ignore + result_tensor = torch.stack(result, 0) + return result_tensor, hidden + + @classmethod + def from_params(cls, *args, **kwargs): + cell = LSTMCell.from_params(*args, **kwargs) + layer = cls(cell.input_size, cell.hidden_size, cell.bias) + layer.cell = cell + return layer + + +class _LSTMLayer(torch.nn.Module): + r"""A single bi-directional LSTM layer.""" + def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, + batch_first: bool = False, bidirectional: bool = False): + super().__init__() + self.batch_first = batch_first + self.bidirectional = bidirectional + self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias) + if self.bidirectional: + self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias) + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + if hidden is None: + hx_fw, cx_fw = (None, None) + else: + hx_fw, cx_fw = hidden + if self.bidirectional: + if hx_fw is None: + hx_bw = None + else: + hx_bw = hx_fw[1] + hx_fw = hx_fw[0] + if cx_fw is None: + cx_bw = None + else: + cx_bw = cx_fw[1] + cx_fw = cx_fw[0] + hidden_bw = hx_bw, cx_bw + hidden_fw = hx_fw, cx_fw + result_fw, hidden_fw = self.layer_fw(x, hidden_fw) + + if self.bidirectional: + x_reversed = x.flip(0) + result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) + result_bw = result_bw.flip(0) + + result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) + h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore + c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore + else: + result = result_fw + h, c = hidden_fw # type: ignore + + if self.batch_first: + result.transpose_(0, 1) + + return result, (h, c) + + @classmethod + def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): + r""" + There is no FP equivalent of this class. This function is here just to + mimic the behavior of the `prepare` within the `torch.quantization` + flow. 
+ """ + assert hasattr(other, 'qconfig') or (qconfig is not None) + + input_size = kwargs.get('input_size', other.input_size) + hidden_size = kwargs.get('hidden_size', other.hidden_size) + bias = kwargs.get('bias', other.bias) + batch_first = kwargs.get('batch_first', other.batch_first) + bidirectional = kwargs.get('bidirectional', other.bidirectional) + + layer = cls(input_size, hidden_size, bias, batch_first, bidirectional) + layer.qconfig = getattr(other, 'qconfig', qconfig) + wi = getattr(other, f'weight_ih_l{layer_idx}') + wh = getattr(other, f'weight_hh_l{layer_idx}') + bi = getattr(other, f'bias_ih_l{layer_idx}', None) + bh = getattr(other, f'bias_hh_l{layer_idx}', None) + + layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) + + if other.bidirectional: + wi = getattr(other, f'weight_ih_l{layer_idx}_reverse') + wh = getattr(other, f'weight_hh_l{layer_idx}_reverse') + bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None) + bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None) + layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) + return layer + + # Getters for the weights and biases + # Note that jit currently doesn't support the `porperty`, so if you need to + # access the weights/biases you would need to navigate manually to the + # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883 + @property + def weight_ih(self): + return self.layer_fw.cell.igates.weight + + @property + def weight_hh(self): + return self.layer_fw.cell.hgates.weight + + @property + def bias_ih(self): + return self.layer_fw.cell.igates.bias + + @property + def bias_hh(self): + return self.layer_fw.cell.hgates.bias + + @property + def weight_ih_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.igates.weight + + @property + def weight_hh_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.hgates.weight + + @property + def bias_ih_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.igates.bias + + @property + def bias_hh_reverse(self): + assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' + return self.layer_bw.cell.hgates.bias + + +class LSTM(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples below. 
+ + Examples:: + + >>> import torch.nn.quantizable as nnqa + >>> rnn = nnqa.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + >>> # To get the weights: + >>> print(rnn.layers[0].weight_ih) + tensor([[...]]) + >>> print(rnn.layers[0].weight_hh) + AssertionError: There is no reverse path in the non-bidirectional layer + """ + _FLOAT_MODULE = torch.nn.LSTM + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, + batch_first: bool = False, dropout: float = 0., + bidirectional: bool = False): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.training = False # We don't want to train using this module + num_directions = 2 if bidirectional else 1 + + if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ + isinstance(dropout, bool): + raise ValueError("dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed") + if dropout > 0: + warnings.warn("dropout option for quantizable LSTM is ignored. " + "If you are training, please, use nn.LSTM version " + "followed by `prepare` step.") + if num_layers == 1: + warnings.warn("dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + "num_layers greater than 1, but got dropout={} " + "and num_layers={}".format(dropout, num_layers)) + + layers = [_LSTMLayer(self.input_size, self.hidden_size, + self.bias, batch_first=False, + bidirectional=self.bidirectional)] + for layer in range(1, num_layers): + layers.append(_LSTMLayer(self.hidden_size, self.hidden_size, + self.bias, batch_first=False, + bidirectional=self.bidirectional)) + self.layers = torch.nn.ModuleList(layers) + + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + + max_batch_size = x.size(1) + num_directions = 2 if self.bidirectional else 1 + if hidden is None: + zeros = torch.zeros(num_directions, max_batch_size, + self.hidden_size, dtype=torch.float, + device=x.device) + zeros.squeeze_(0) + if x.is_quantized: + zeros = torch.quantize_per_tensor(zeros, scale=1.0, + zero_point=0, dtype=x.dtype) + hxcx = [(zeros, zeros) for _ in range(self.num_layers)] + else: + hidden_non_opt = torch.jit._unwrap_optional(hidden) + if isinstance(hidden_non_opt[0], Tensor): + hx = hidden_non_opt[0].reshape(self.num_layers, num_directions, + max_batch_size, + self.hidden_size).unbind(0) + cx = hidden_non_opt[1].reshape(self.num_layers, num_directions, + max_batch_size, + self.hidden_size).unbind(0) + hxcx = [] + for idx in range(self.num_layers): + hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0))) + else: + hxcx = hidden_non_opt + + for idx in range(self.num_layers): + x, hxcx[idx] = self.layers[idx](x, hxcx[idx]) + + hx_list = [] + cx_list = [] + for idx in range(self.num_layers): + hx_list.append(hxcx[idx][0]) + cx_list.append(hxcx[idx][1]) + hx_tensor = torch.stack(hx_list) + cx_tensor = torch.stack(cx_list) + + # We are creating another dimension for bidirectional case + # need to collapse it + hx_tensor = hx_tensor.reshape(-1, *hx_tensor.shape[-2:]) + cx_tensor = cx_tensor.reshape(-1, *cx_tensor.shape[-2:]) + + if self.batch_first: + x = x.transpose(0, 1) + + return x, 
(hx_tensor, cx_tensor) + + def _get_name(self): + return 'QuantizableLSTM' + + @classmethod + def from_float(cls, other, qconfig=None): + assert isinstance(other, cls._FLOAT_MODULE) + assert (hasattr(other, 'qconfig') or qconfig) + observed = cls(other.input_size, other.hidden_size, other.num_layers, + other.bias, other.batch_first, other.dropout, + other.bidirectional) + observed.qconfig = getattr(other, 'qconfig', qconfig) + for idx in range(other.num_layers): + observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig, + batch_first=False) + observed.eval() + observed = torch.quantization.prepare(observed, inplace=True) + return observed + + def from_observed(self, other): + return torch.quantization.convert(self, inplace=False, + remove_qconfig=True) diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py index eb1596832c4d..aeba95bb4e8f 100644 --- a/torch/quantization/_numeric_suite_fx.py +++ b/torch/quantization/_numeric_suite_fx.py @@ -21,7 +21,7 @@ def remove_qconfig_observer_fx(model): # remove activation post process act_post_process_removed_graph = Graph() - env = {} # type: Dict[str, Any] + env: Dict[str, Any] = {} modules = dict(model.named_modules()) diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index f0ee8453557d..460b1c277a93 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -41,8 +41,7 @@ def calculate_qparams(self, **kwargs): pass @torch.jit.export - def enable_fake_quant(self, enabled=True): - # type: (bool) -> None + def enable_fake_quant(self, enabled: bool = True) -> None: self.fake_quant_enabled[0] = 1 if enabled else 0 @torch.jit.export @@ -50,8 +49,7 @@ def disable_fake_quant(self): self.enable_fake_quant(False) @torch.jit.export - def enable_observer(self, enabled=True): - # type: (bool) -> None + def enable_observer(self, enabled: bool = True) -> None: self.observer_enabled[0] = 1 if enabled else 0 @torch.jit.export diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 32d07c939695..7addaa622962 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -877,8 +877,7 @@ def _combine_histograms(self, orig_hist = orig_hist + interpolated_histogram.to(torch.float) return orig_hist - def forward(self, x_orig): - # type: (torch.Tensor) -> torch.Tensor + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: x = x_orig.detach() min_val = self.min_val max_val = self.max_val diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index a9417ecb80f3..1be867e0a299 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.quantized as nnq +import torch.nn.quantizable as nnqa from torch.nn.intrinsic import _FusedModule from .quantization_mappings import ( @@ -152,7 +153,10 @@ def insert_activation_post_process(m, special_act_post_process=None): elif needs_observation(child) and type(child) in custom_module_class_mapping: observed_child = custom_module_class_mapping[type(child)].from_float(child) setattr(module, name, observed_child) - insert_activation_post_process(observed_child) + # TODO: These are the modules that cannot be observed + # Once there are more, we should move them to a separate list + if custom_module_class_mapping[type(child)] != nnqa.LSTM: + insert_activation_post_process(observed_child) else: add_observer_(child, qconfig_propagation_list, non_leaf_module_list, 
device, custom_module_class_mapping) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 87d0baa895e8..0c5a5a6353df 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1001,6 +1001,16 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]), )), + UnaryUfuncInfo('rsqrt', + ref=lambda x: np.reciprocal(np.sqrt(x)), + domain=(0, float('inf')), + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + decorators=(precisionOverride({torch.half: 5e-2}),), + promotes_integers_to_float=True, + assert_autodiffed=True, + handles_complex_extremals=False), UnaryUfuncInfo('sqrt', ref=np.sqrt, domain=(0, float('inf')), @@ -1466,6 +1476,10 @@ def method_tests(): ('ceil', (), NO_ARGS, 'scalar', (True,)), ('rad2deg', (S, S, S), NO_ARGS), ('deg2rad', (S, S, S), NO_ARGS), + # Removing the 'rsqrt' entries leads to failure in + # test_index_fill_variable_dim_* + # TODO: Remove when fixed. + # Reference: https://github.com/pytorch/pytorch/issues/48230 ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), ('rsqrt', torch.rand(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index 243cd964b96d..f14556597128 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -102,6 +102,35 @@ def _calculate_dynamic_per_channel_qparams(X, dtype): return scale, zero_point +def _snr(x, x_hat): + """Calculates the signal to noise ratio and returns the signal and noise + power, as well as the SNR in dB. + If the input is a list/tuple this function is called recursively on each + element. The result will have the same nested structure as the inputs. + + Args: + x, x_hat: Either a tensor or a nested list/tuple of tensors. + Returns: + signal, noise, SNR(in dB): Either floats or a nested list of floats + """ + if isinstance(x, (list, tuple)): + assert(len(x) == len(x_hat)) + res = [] + for idx in range(len(x)): + res.append(_snr(x[idx], x_hat[idx])) + return res + if x_hat.is_quantized: + x_hat = x_hat.dequantize() + if x.is_quantized: + x = x.dequantize() + noise = (x - x_hat).norm() + if noise == 0: + return 0.0, float('inf'), float('inf') + signal = x.norm() + snr = signal / noise + snr_db = 20 * snr.log10() + return signal, noise, snr_db + @contextmanager def override_quantized_engine(qengine): previous = torch.backends.quantized.engine diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index df6919cf65b0..bea572722ae6 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -303,11 +303,16 @@ def run_tests(argv=UNITTEST_ARGS): if IS_WINDOWS: @contextmanager - def TemporaryFileName(dir=None): + def TemporaryFileName(*args, **kwargs): # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile # opens the file, and it cannot be opened multiple times in Windows. 
To support Windows, # close the file after creation and try to remove it manually - f = tempfile.NamedTemporaryFile(delete=False, dir=dir) + if 'delete' in kwargs: + if kwargs['delete'] is not False: + raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.") + else: + kwargs['delete'] = False + f = tempfile.NamedTemporaryFile(*args, **kwargs) try: f.close() yield f.name @@ -315,8 +320,8 @@ def TemporaryFileName(dir=None): os.unlink(f.name) else: @contextmanager # noqa: T484 - def TemporaryFileName(dir=None): - with tempfile.NamedTemporaryFile(dir=dir) as f: + def TemporaryFileName(*args, **kwargs): + with tempfile.NamedTemporaryFile(*args, **kwargs) as f: yield f.name if IS_WINDOWS: