diff --git a/CMakeLists.txt b/CMakeLists.txt index e346087c0cdb..3df73f8a3041 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option( USE_NCCL "Use NCCL" ON "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_RCCL "Use RCCL" ON + USE_NCCL OFF) cmake_dependent_option( USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index fd3c95f2573b..6fedef185b21 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -72,7 +72,7 @@ file(GLOB metal_h "metal/*.h") file(GLOB metal_cpp "metal/*.cpp") file(GLOB_RECURSE native_metal_h "native/metal/*.h") file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm") -file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp") +file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm" "native/metal/*.cpp") EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs}) file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h") file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp") diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h index 5063beeb08b0..8fa085245459 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/VmapTransforms.h @@ -96,8 +96,17 @@ struct VmapPhysicalToLogicalMap; // The levels bitset specifies which vmap levels correspond to the batch // dimensions at the front of the tensor. In particular, the number of set bits // corresponds to the number of batch dimensions on `tensor` and the rightmost -// bit of `levels` specifies the minimum number of nested vmaps we are in at +// bit of `levels` specifies the maximum number of nested vmaps we are in at // this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. 
+// bitset: 010100 +// ^ +// | +// levels: 012345 struct TORCH_API VmapPhysicalView { VmapPhysicalView(Tensor&& tensor, std::bitset levels) : levels_(levels), tensor_(tensor) { diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 320fa6294638..1223577c59c6 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -265,7 +265,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) { TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr); TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr); return lhs.tag == rhs.tag && - lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } IValue IValue::equals(const IValue& rhs) const { @@ -325,17 +325,17 @@ size_t IValue::hash(const IValue& v) { case Tag::None: return 0; case Tag::Bool: - return c10::get_hash(v.payload.as_bool); + return c10::get_hash(v.payload.u.as_bool); case Tag::Double: - return c10::get_hash(v.payload.as_double); + return c10::get_hash(v.payload.u.as_double); case Tag::Tensor: // Tensor __hash__ is equivalent to `id()`, so take the pointer value of // the tensor to emulate it - return c10::get_hash(v.payload.as_int); + return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl()); case Tag::Storage: - return c10::get_hash(v.payload.as_int); + return c10::get_hash(v.payload.u.as_int); case Tag::Int: - return c10::get_hash(v.payload.as_int); + return c10::get_hash(v.payload.u.as_int); case Tag::String: return c10::get_hash(v.toStringRef()); case Tag::Tuple: diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 4a7e15c4008b..ca68a8df46e1 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -131,10 +131,15 @@ struct Capsule { // they are marked `@private`, which hides them on the doxygen documentation for // this page. -/// IValue (Interpreter Value) is a tagged union over the types supported by the -/// TorchScript interpreter. IValues contain their values as an -/// `IValue::Payload`, which holds primitive types (`int64_t`, `bool`, `double`, -/// `Device`), as values and all other types as a `c10::intrusive_ptr`. +/// IValue (Interpreter Value) is a tagged union over the types +/// supported by the TorchScript interpreter. IValues contain their +/// values as an `IValue::Payload`, which holds primitive types +/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values, +/// and all other types as a `c10::intrusive_ptr`. In order to +/// optimize performance of the destructor and related operations by +/// making the `Tensor` and `c10::intrusive_ptr` paths generate the +/// same code, we represent a null `c10::intrusive_ptr` as +/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`. /// /// IValues are used as inputs to and outputs from the TorchScript interpreter. 
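
A minimal standalone sketch of the never-null convention described in the doc comment above may help; the names RefTarget, undefined_singleton, and MiniValue are invented for illustration and merely stand in for c10::intrusive_ptr_target, c10::UndefinedTensorImpl::singleton(), and IValue, and deallocation is omitted.

#include <atomic>

// The "null" state is a live, never-freed singleton, so its refcount can be
// bumped and dropped like any other target's.
struct RefTarget {
  std::atomic<int> refcount{1};
};

RefTarget* undefined_singleton() {
  static RefTarget s;
  return &s;
}

struct MiniValue {
  RefTarget* ptr;  // invariant: never nullptr
  explicit MiniValue(RefTarget* p) : ptr(p ? p : undefined_singleton()) {
    ptr->refcount.fetch_add(1, std::memory_order_relaxed);
  }
  ~MiniValue() {
    // One unconditional decref on every path. Because "null" is a live
    // singleton rather than nullptr, no branch is needed here, which is what
    // lets the Tensor and intrusive_ptr destruction paths share generated code.
    ptr->refcount.fetch_sub(1, std::memory_order_acq_rel);
  }
};
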
/// To retrieve the value contained within an IValue, use the `.toX()` methods, @@ -160,27 +165,35 @@ struct Capsule { struct TORCH_API IValue final { IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) { - if (is_intrusive_ptr) { - c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr); + if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); } } - IValue(IValue&& rhs) noexcept : IValue() { - swap(rhs); + + IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { + moveFrom(std::move(rhs)); } + /// @private [doxygen private] ~IValue() { - if (is_intrusive_ptr) { - c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr); - } + destroy(); } - IValue& operator=(IValue&& rhs) & noexcept { - IValue(std::move(rhs)).swap(*this); // this also sets rhs to None + + C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { + if (&rhs == this) { + return *this; + } + + destroy(); + moveFrom(std::move(rhs)); return *this; } + IValue& operator=(IValue const& rhs) & { IValue(rhs).swap(*this); return *this; } + void dump() const; /** @@ -260,6 +273,13 @@ struct TORCH_API IValue final { return false; } + // Tensors should be compared based on internal storage + if (this->isTensor()) { + const auto& thisTensor = this->toTensor(); + const auto& rhsTensor = rhs.toTensor(); + return thisTensor.is_alias_of(rhsTensor); + } + if (!this->is_intrusive_ptr) { // Primitive types don't alias anything return false; @@ -267,29 +287,49 @@ struct TORCH_API IValue final { AT_ASSERT(rhs.is_intrusive_ptr); - // Tensors should be compared based on internal storage - if (this->isTensor()) { - const auto thisTensor = this->toTensor(); - const auto rhsTensor = rhs.toTensor(); - return thisTensor.is_alias_of(rhsTensor); - } - // Other types can be compared by their ptr value - return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } /// @private [doxygen private] size_t use_count() const noexcept { + if (isTensor()) { + return payload.as_tensor.use_count(); + } + if (!is_intrusive_ptr) { return 1; } - return c10::raw::intrusive_ptr::use_count(payload.as_intrusive_ptr); + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return 0; + } + return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr); } /// @private [doxygen private] void swap(IValue& rhs) noexcept { - std::swap(payload, rhs.payload); + if (isTensor() && rhs.isTensor()) { + std::swap(payload.as_tensor, rhs.payload.as_tensor); + } else if (isTensor()) { + at::Tensor t = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. 
+ // + // payload.as_tensor.~Tensor(); + payload.u = rhs.payload.u; + new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); + } else if (rhs.isTensor()) { + rhs.swap(*this); + return; + } else { + std::swap(payload.u, rhs.payload.u); + } std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); std::swap(tag, rhs.tag); } @@ -298,21 +338,17 @@ struct TORCH_API IValue final { // While some of these accessors could be generated through templates, // we prefer to write them manually for clarity - IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) { - // Note: the undefined tensor is not refcounted, so while it - // is tagged as a tensor, is_intrusive_ptr is set to false. - // This is not an optional optimization: our incref call - // *will not* do the right thing when called on an - // undefined tensor. - payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl(); + IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) { + new (&payload.as_tensor) at::Tensor(std::move(t)); } bool isTensor() const { return Tag::Tensor == tag; } at::Tensor toTensor() &&; - at::Tensor toTensor() const&; + at::Tensor& toTensor() &; + const at::Tensor& toTensor() const&; at::TensorImpl* unsafeToTensorImpl() const { - return static_cast(payload.as_intrusive_ptr); + return payload.as_tensor.unsafeGetTensorImpl(); } IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast(s)) { @@ -321,7 +357,7 @@ struct TORCH_API IValue final { // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined tensor. - payload.as_intrusive_ptr = s.unsafeReleaseStorageImpl(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl()); } bool isStorage() const { return Tag::Storage == tag; @@ -341,7 +377,7 @@ struct TORCH_API IValue final { : tag(Tag::Blob), is_intrusive_ptr(true) { // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract // and store it as a Tensor instead. 
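
To make the mixed branch of swap() above easier to follow, here is a self-contained sketch of the same placement-new technique on a toy tagged union, with std::string standing in for at::Tensor; Boxed, Payload, and swapWith are invented names. Unlike a moved-from Tensor, a moved-from std::string's destructor is not guaranteed to be a no-op, so the explicit destructor call that the patch comments out is kept here.

#include <cstdint>
#include <new>
#include <string>
#include <utility>

struct Boxed {
  enum class Tag { Int, Str } tag;
  union Payload {
    int64_t as_int;
    std::string as_str;
    Payload() : as_int(0) {}
    ~Payload() {}  // the active member is destroyed manually, based on the tag
  } payload;

  explicit Boxed(int64_t v) : tag(Tag::Int) { payload.as_int = v; }
  explicit Boxed(std::string v) : tag(Tag::Str) {
    new (&payload.as_str) std::string(std::move(v));
  }
  ~Boxed() {
    if (tag == Tag::Str) {
      payload.as_str.~basic_string();
    }
  }
  Boxed(const Boxed&) = delete;
  Boxed& operator=(const Boxed&) = delete;

  void swapWith(Boxed& rhs) {
    if (tag == Tag::Str && rhs.tag == Tag::Str) {
      std::swap(payload.as_str, rhs.payload.as_str);
    } else if (tag == Tag::Str) {
      // Move the non-trivial member aside, copy the trivial bits across,
      // then placement-new the string into the other payload.
      std::string tmp = std::move(payload.as_str);
      payload.as_str.~basic_string();
      payload.as_int = rhs.payload.as_int;
      new (&rhs.payload.as_str) std::string(std::move(tmp));
    } else if (rhs.tag == Tag::Str) {
      rhs.swapWith(*this);
      return;
    } else {
      std::swap(payload.as_int, rhs.payload.as_int);
    }
    std::swap(tag, rhs.tag);
  }
};
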
- payload.as_intrusive_ptr = blob.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); } /// @private [doxygen private] @@ -397,14 +433,14 @@ struct TORCH_API IValue final { // Double IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) { - payload.as_double = d; + payload.u.as_double = d; } bool isDouble() const { return Tag::Double == tag; } double toDouble() const { AT_ASSERT(isDouble()); - return payload.as_double; + return payload.u.as_double; } // Future @@ -433,7 +469,7 @@ struct TORCH_API IValue final { // Int IValue(int64_t i) : tag(Tag::Int), is_intrusive_ptr(false) { - payload.as_int = i; + payload.u.as_int = i; } // allow you to pass literals (3, 4) without ambiguity @@ -445,7 +481,7 @@ struct TORCH_API IValue final { int64_t toInt() const { AT_ASSERT(isInt()); - return payload.as_int; + return payload.u.as_int; } // Bool @@ -454,9 +490,9 @@ struct TORCH_API IValue final { // Initializing entire payload stops valgrind's from reporting // "jump or move depends on uninitialised value" in IValue copy constructor // See https://github.com/pytorch/pytorch/issues/37117 - payload.as_int = b; + payload.u.as_int = b; #else - payload.as_bool = b; + payload.u.as_bool = b; #endif } bool isBool() const { @@ -464,7 +500,7 @@ struct TORCH_API IValue final { } bool toBool() const { AT_ASSERT(isBool()); - return payload.as_bool; + return payload.u.as_bool; } // IntList @@ -580,7 +616,7 @@ struct TORCH_API IValue final { c10::intrusive_ptr toEnumHolder() const&; // None - IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {} + IValue() : tag(Tag::None), is_intrusive_ptr(false) {} bool isNone() const { return Tag::None == tag; } @@ -616,21 +652,21 @@ struct TORCH_API IValue final { // Device IValue(c10::Device d) : tag(Tag::Device), is_intrusive_ptr(false) { - payload.as_device.type = d.type(); - payload.as_device.index = d.index(); + payload.u.as_device.type = d.type(); + payload.u.as_device.index = d.index(); } bool isDevice() const { return Tag::Device == tag; } c10::Device toDevice() const { AT_ASSERT(isDevice()); - return c10::Device(payload.as_device.type, payload.as_device.index); + return c10::Device(payload.u.as_device.type, payload.u.as_device.index); } //Stream IValue(c10::Stream stream) : tag(Tag::Stream), is_intrusive_ptr(false) { - payload.as_int = stream.pack(); + payload.u.as_int = stream.pack(); } c10::Stream toStream() &&; c10::Stream toStream() const &; @@ -659,7 +695,7 @@ struct TORCH_API IValue final { // QScheme IValue(at::QScheme qscheme) : tag(Tag::Int), is_intrusive_ptr(false) { - payload.as_int = static_cast(qscheme); + payload.u.as_int = static_cast(qscheme); } at::QScheme toQScheme() const { @@ -680,7 +716,7 @@ struct TORCH_API IValue final { // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined generator. 
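
A short usage sketch of the constructors and accessors as they stand after this patch (the function name and literal values are arbitrary): the Tensor is held inline in the payload, primitives live in payload.u, and the new lvalue toTensor() overload hands back a reference instead of bumping the TensorImpl refcount.

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

void ivalue_usage_sketch() {
  c10::IValue tv(at::ones({2, 3}));          // Tensor stored inline in the payload
  c10::IValue iv(static_cast<int64_t>(7));   // primitives stay in payload.u
  c10::IValue bv(true);
  c10::IValue dv(0.5);

  // Reference into the payload -- no refcount traffic just to peek at the tensor.
  const at::Tensor& t = tv.toTensor();

  int64_t n = t.numel() + iv.toInt() + static_cast<int64_t>(bv.toBool()) +
      static_cast<int64_t>(dv.toDouble());
  (void)n;
}
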
- payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl()); } bool isGenerator() const { return Tag::Generator == tag; @@ -749,14 +785,19 @@ struct TORCH_API IValue final { const IValue& v); bool isPtrType() const { - return is_intrusive_ptr; + return (isTensor() && payload.as_tensor.defined()) || is_intrusive_ptr; } /// @private [doxygen private] const void* internalToPointer() const { TORCH_INTERNAL_ASSERT( isPtrType(), "Can only call internalToPointer() for pointer types"); - return payload.as_intrusive_ptr; + if (isTensor()) { + return payload.as_tensor.unsafeGetTensorImpl(); + } else { + return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() + ? payload.u.as_intrusive_ptr : nullptr; + } } TypePtr type() const; @@ -770,7 +811,7 @@ struct TORCH_API IValue final { } // If it is not a Tensor, then two mutable IValues alias each other only // if they are the same pointer. - return val.payload.as_int; + return val.payload.u.as_int; } }; @@ -800,6 +841,10 @@ struct TORCH_API IValue final { IValue deepcopy(HashAliasedIValueMap& memo) const; private: + static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) { + return p ? p : static_cast(c10::UndefinedTensorImpl::singleton()); + } + static bool ptrEqual(const IValue& lhs, const IValue& rhs); // NOTE: IValue tags are intentionally private. In the future we may encode // this value different (e.g. using NaN boxing), and this would make it more @@ -822,24 +867,77 @@ struct TORCH_API IValue final { class NullType = c10::detail::intrusive_target_default_null_type> c10::intrusive_ptr toIntrusivePtr() const; - void clearToNone() { - payload.as_int = 0; + void destroy() { + // We carefully construct this call to both 1) avoid UB by using + // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable + // the compiler to generate the same code for each case. It is + // surprisingly difficult to get this right. + if (isTensor() || is_intrusive_ptr) { + c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr; + c10::intrusive_ptr::reclaim(p); + // No need to make this destructor call! + // payload.as_tensor.~Tensor(); + } + } + + C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept { + if (rhs.isTensor()) { + new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // rhs.payload.as_tensor.~Tensor(); + } else { + payload.u = rhs.payload.u; + } + tag = rhs.tag; + is_intrusive_ptr = rhs.is_intrusive_ptr; + rhs.clearToNone(); + } + + void clearToNone() noexcept { + payload.u.as_int = 0; tag = Tag::None; is_intrusive_ptr = false; } union Payload { - int64_t as_int; - double as_double; - bool as_bool; - c10::intrusive_ptr_target* as_intrusive_ptr; - struct { - DeviceType type; - DeviceIndex index; - } as_device; + // We use a nested union here so that we can make the copy easy + // and efficient in the non-tensor (i.e., trivially copyable) + // case. 
Specifically, we do not have to do a switch-on-tag to + // figure out which union member to assign; we can just use + // TriviallyCopyablePayload::operator=. + union TriviallyCopyablePayload { + TriviallyCopyablePayload() : as_int(0) {} + int64_t as_int; + double as_double; + bool as_bool; + // Invariant: never nullptr; null state is represented as + // c10::UndefinedTensorImpl::singleton() for consistency of + // representation with Tensor. + c10::intrusive_ptr_target* as_intrusive_ptr; + struct { + DeviceType type; + DeviceIndex index; + } as_device; + } u; + at::Tensor as_tensor; + Payload() : u() {} + ~Payload() {} }; - IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {} + IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) { + if (isTensor()) { + new (&payload.as_tensor) at::Tensor(p.as_tensor); + } else { + payload.u = p.u; + } + } Payload payload; Tag tag; @@ -848,29 +946,36 @@ struct TORCH_API IValue final { }; struct TORCH_API WeakIValue final { - WeakIValue() : payload{0}, tag(IValue::Tag::None), is_intrusive_ptr(false) {} + WeakIValue() : tag(IValue::Tag::None), is_intrusive_ptr(false) {} WeakIValue(const WeakIValue& rhs) : payload(rhs.payload), tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { - if (is_intrusive_ptr) { + if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); } } WeakIValue(const IValue& rhs) - : payload(rhs.payload), - tag(rhs.tag), + : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { + if (rhs.isTensor()) { + payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); + is_intrusive_ptr = true; + } else { + payload = rhs.payload.u; + } if (is_intrusive_ptr) { - c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } } } WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() { swap(rhs); } ~WeakIValue() { - if (is_intrusive_ptr) { + if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr); } } @@ -895,17 +1000,33 @@ struct TORCH_API WeakIValue final { IValue lock() const { if (!is_intrusive_ptr) { - return IValue(payload, tag, false); + IValue::Payload newPayload; + newPayload.u = payload; + return IValue(newPayload, tag, false); } - auto temp = c10::weak_intrusive_ptr::reclaim( - payload.as_intrusive_ptr); - IValue::Payload pl; - pl.as_intrusive_ptr = temp.lock().release(); - temp.release(); - if (!pl.as_intrusive_ptr) { - return IValue(); + if (IValue::Tag::Tensor == tag) { + auto temp = c10::weak_intrusive_ptr::reclaim( + static_cast(payload.as_intrusive_ptr)); + c10::intrusive_ptr ip(temp.lock()); + temp.release(); + if (!ip) { + return IValue(); + } else { + return IValue(at::Tensor(std::move(ip))); + } } else { - return IValue(pl, tag, true); + auto temp = c10::weak_intrusive_ptr::reclaim( + payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? 
nullptr + : payload.as_intrusive_ptr); + IValue::Payload pl; + pl.u.as_intrusive_ptr = temp.lock().release(); + temp.release(); + if (!pl.u.as_intrusive_ptr) { + return IValue(); + } else { + return IValue(pl, tag, true); + } } } @@ -913,7 +1034,7 @@ struct TORCH_API WeakIValue final { if (!is_intrusive_ptr) { return 1; } - auto temp = c10::weak_intrusive_ptr::reclaim( + auto temp = c10::weak_intrusive_ptr::reclaim( payload.as_intrusive_ptr); size_t result = temp.use_count(); temp.release(); @@ -924,7 +1045,7 @@ struct TORCH_API WeakIValue final { if (!is_intrusive_ptr) { return 1; } - auto temp = c10::weak_intrusive_ptr::reclaim( + auto temp = c10::weak_intrusive_ptr::reclaim( payload.as_intrusive_ptr); size_t result = temp.weak_use_count(); temp.release(); @@ -935,7 +1056,8 @@ struct TORCH_API WeakIValue final { } private: - IValue::Payload payload; + using Payload = IValue::Payload::TriviallyCopyablePayload; + Payload payload; IValue::Tag tag; bool is_intrusive_ptr; }; diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 89c8e669c138..b96f4b834989 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -48,14 +48,18 @@ struct tagged_capsule { template c10::intrusive_ptr IValue::moveToIntrusivePtr() { auto t = c10::intrusive_ptr::reclaim( - static_cast(payload.as_intrusive_ptr)); + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); clearToNone(); return t; } template c10::intrusive_ptr IValue::toIntrusivePtr() const { auto r = c10::intrusive_ptr::reclaim( - static_cast(payload.as_intrusive_ptr)); + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); auto p = r; r.release(); return p; @@ -131,12 +135,26 @@ inline c10::intrusive_ptr IValue::toEnumHolder() const& { } inline at::Tensor IValue::toTensor() && { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); - return at::Tensor( - moveToIntrusivePtr()); + auto result = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. 
+ // + // payload.as_tensor.~Tensor(); + clearToNone(); + return result; } -inline at::Tensor IValue::toTensor() const& { +inline at::Tensor& IValue::toTensor() & { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); - return at::Tensor(toIntrusivePtr()); + return payload.as_tensor; +} +inline const at::Tensor& IValue::toTensor() const& { + AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); + return payload.as_tensor; } inline c10::Storage IValue::toStorage() && { AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); @@ -148,10 +166,10 @@ inline c10::Storage IValue::toStorage() const& { return c10::Storage(toIntrusivePtr()); } inline c10::Stream IValue::toStream() && { - return c10::Stream::unpack(payload.as_int); + return c10::Stream::unpack(payload.u.as_int); } inline c10::Stream IValue::toStream() const& { - return c10::Stream::unpack(payload.as_int); + return c10::Stream::unpack(payload.u.as_int); } inline c10::intrusive_ptr IValue::toBlob() && { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); @@ -713,7 +731,8 @@ using _guarded_unsigned_long = std::conditional_t< inline const ivalue::Object& IValue::toObjectRef() const { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); - return *static_cast(payload.as_intrusive_ptr); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference"); + return *static_cast(payload.u.as_intrusive_ptr); } // note: when adding a DEFINE_TO case here you should also add a @@ -729,6 +748,7 @@ inline const ivalue::Object& IValue::toObjectRef() const { inline type IValue::to() const& { \ return this->method_name(); \ } + DEFINE_TO(at::Tensor, toTensor) DEFINE_TO(at::Storage, toStorage) DEFINE_TO(c10::Stream, toStream) @@ -980,8 +1000,11 @@ inline c10::List IValue::toIntList() const& { } inline std::vector IValue::toIntVector() const { AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toIntVector on null intrusive_ptr IValue"); return createVectorFromList( - static_cast(payload.as_intrusive_ptr)); + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toDoubleList() && { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); @@ -993,8 +1016,11 @@ inline c10::List IValue::toDoubleList() const& { } inline std::vector IValue::toDoubleVector() const { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDoubleVector on null intrusive_ptr IValue"); return createVectorFromList( - static_cast(payload.as_intrusive_ptr)); + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toBoolList() && { AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); @@ -1014,8 +1040,11 @@ inline c10::List IValue::toTensorList() const& { } inline std::vector IValue::toTensorVector() const { AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTensorVector on null intrusive_ptr IValue"); return createVectorFromList( - static_cast(payload.as_intrusive_ptr)); + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toList() && { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); @@ 
-1027,7 +1056,10 @@ inline c10::List IValue::toList() const& { } inline c10::ArrayRef IValue::toListRef() const { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); - return static_cast(payload.as_intrusive_ptr) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toListRef on null intrusive_ptr IValue"); + return static_cast(payload.u.as_intrusive_ptr) ->list; } inline c10::Dict IValue::toGenericDict() && { @@ -1049,7 +1081,7 @@ inline c10::intrusive_ptr IValue::toTuple() const& { inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Tuple), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } template < typename... Args, @@ -1065,14 +1097,14 @@ inline IValue::IValue(const std::tuple& t) inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::String), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(std::string v) : IValue(ivalue::ConstantString::create(std::move(v))) {} inline IValue::IValue(c10::impl::GenericList v) : tag(Tag::GenericList), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.impl_.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } template > @@ -1104,7 +1136,7 @@ inline IValue::IValue(std::array v) : IValue(c10::List()) { inline IValue::IValue(c10::impl::GenericDict v) : tag(Tag::GenericDict), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.impl_.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } template inline IValue::IValue(c10::Dict v) @@ -1131,17 +1163,17 @@ inline IValue::IValue(c10::nullopt_t) : IValue() {} inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Object), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::PyObject), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Enum), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue IValue::make_capsule( @@ -1149,7 +1181,7 @@ inline IValue IValue::make_capsule( IValue iv; iv.tag = Tag::Capsule; iv.is_intrusive_ptr = true; - iv.payload.as_intrusive_ptr = blob.release(); + iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); return iv; } @@ -1170,30 +1202,33 @@ IValue::IValue(c10::intrusive_ptr custom_class) { auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); - payload.as_intrusive_ptr = ivalue_obj.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release()); tag = Tag::Object; is_intrusive_ptr = true; } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Future), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::RRef), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } 
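
The accessors above all follow one convention, sketched here with invented names (Target, sentinel, store, load): nullptr is swapped for a sentinel when a released pointer is stored, and the sentinel is swapped back for the null representation when the pointer is reclaimed.

struct Target {};

Target* sentinel() {
  static Target s;
  return &s;
}

// On the way in (cf. null_to_undefined_tensor): a released null pointer is
// replaced by the sentinel, so the stored pointer is never nullptr.
Target* store(Target* released) {
  return released != nullptr ? released : sentinel();
}

// On the way out (cf. moveToIntrusivePtr / toIntrusivePtr above): the sentinel
// is mapped back to the caller-visible "null" representation.
Target* load(Target* stored) {
  return stored == sentinel() ? nullptr : stored;
}

The debug-only asserts added to toIntVector, toDoubleVector, toTensorVector, and toListRef exist because those fast paths dereference the stored pointer directly, without going through this mapping.
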
inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Quantizer), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline const std::string& IValue::toStringRef() const { AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringRef on null intrusive_ptr IValue"); return static_cast( - payload.as_intrusive_ptr) + payload.u.as_intrusive_ptr) ->string(); } inline c10::optional> IValue:: @@ -1202,8 +1237,11 @@ inline c10::optional> IValue:: return c10::nullopt; } AT_ASSERT(isString(), "Expected optional but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toOptionalStringRef on null intrusive_ptr IValue"); return std::reference_wrapper( - static_cast(payload.as_intrusive_ptr) + static_cast(payload.u.as_intrusive_ptr) ->string()); } @@ -1241,15 +1279,13 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { // for bool type, do equality check return this->toBool() == rhs.toBool(); } else if (this->isTensor() && rhs.isTensor()) { - // for tensor type, just check the as_intrusive_ptr since is_intrusive_ptr - // is false for undefined tensor - return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + return this->payload.as_tensor.is_same(rhs.payload.as_tensor); } else if (this->isTensor() && rhs.isNone()) { // special case: undefined tensor and None are the same identity - return !this->is_intrusive_ptr; + return !this->payload.as_tensor.defined(); } else if (this->isNone() && rhs.isTensor()) { // special case: undefined tensor and None are the same identity - return !rhs.is_intrusive_ptr; + return !rhs.payload.as_tensor.defined(); } else if (this->isInt() && rhs.isInt()) { return this->toInt() == rhs.toInt(); } else if (this->isDouble() && rhs.isDouble()) { @@ -1260,7 +1296,7 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { // for objects holding in IValue, do shallow compare on pointer address to // testify the identity return this->is_intrusive_ptr && rhs.is_intrusive_ptr && - this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index a3ae813616e0..7d3890f582b8 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -2370,19 +2370,19 @@ struct TORCH_API AnyClassType : public Type { inline bool IValue::isDoubleList() const { // note: avoids calling type() to avoid extra referencing counting for the returned type. 
- return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == FloatType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == FloatType::Kind; } inline bool IValue::isTensorList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == TensorType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == TensorType::Kind; } inline bool IValue::isIntList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == IntType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == IntType::Kind; } inline bool IValue::isBoolList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == BoolType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == BoolType::Kind; } template<> diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index ef0c2e2509c1..413ea32acdef 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -118,7 +118,7 @@ DEFINE_DISPATCH(bernoulli_tensor_stub); DEFINE_DISPATCH(bernoulli_scalar_stub); DEFINE_DISPATCH(cauchy_stub); DEFINE_DISPATCH(exponential_stub); -DEFINE_DISPATCH(multinomial_stub); +DEFINE_DISPATCH(multinomial_with_replacement_stub); DEFINE_DISPATCH(geometric_stub); DEFINE_DISPATCH(log_normal_stub); DEFINE_DISPATCH(uniform_stub); @@ -497,8 +497,10 @@ Tensor& multinomial_out( // Reference: // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 // Half is not supported on CPU. - if (!with_replacement && - !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) { + TORCH_CHECK( + !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half), + "multinomial is not implemented for half on CPU"); + if (!with_replacement) { // Sanity checks on `self`. auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item(); TORCH_CHECK( @@ -537,13 +539,8 @@ Tensor& multinomial_out( return result; } - multinomial_stub( - result.device().type(), - result, - self, - n_sample, - with_replacement, - gen); + multinomial_with_replacement_stub( + result.device().type(), result, self, n_sample, gen); return result; } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 09d50356abd9..d1fadd58d38d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -98,23 +98,25 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor if (dim == dimension) { continue; } - int64_t first_dim_size = first.size(dim); - int64_t second_dim_size = second.size(dim); + int64_t first_dim_size = first.sizes()[dim]; + int64_t second_dim_size = second.sizes()[dim]; TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ", dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim, " (The offending index is ", index, ")"); } } +static bool should_skip(const Tensor& t) { + return t.numel() == 0 && t.dim() == 1; +} + Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors // to be "skipped". 
We maintain this behavior for backwards compatibility, but only for this specific // size (i.e. other empty sizes are not skipped). - // FIXME: warn if this is the case - bool allSkipped = true; + bool allContiguous = true; - Tensor notSkippedTensor; // Inputs cannot alias the output tensor for (int64_t i = 0; i < tensors.size(); i++) { @@ -126,19 +128,23 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { } at::assert_no_internal_overlap(result); - auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; }; - for (auto const &tensor : tensors) { - if (should_skip(tensor)) { - continue; + const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* { + for (auto const &tensor : tensors) { + if (should_skip(tensor)) { + continue; + } + // we've found a non-empty tensor + return &tensor; } - // we've found a non-empty tensor - allSkipped = false; - notSkippedTensor = tensor; - break; - } - if (allSkipped) { + return nullptr; + }(tensors); + + if (!pnotSkippedTensor) { + // FIXME: warn if this is the case -- see comment about skipped + // tensors at top of function. return result; } + const Tensor& notSkippedTensor = *pnotSkippedTensor; TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors"); TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range"); @@ -161,7 +167,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { continue; } check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i); - cat_dim_size += tensor.size(dim); + cat_dim_size += tensor.sizes()[dim]; if (!tensor.is_contiguous(first_tensor_mem_format)) { allContiguous = false; @@ -196,8 +202,8 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { if (reuse_iterator && result.is_contiguous(first_tensor_mem_format) && no_type_promotion) { - auto source_slice = notSkippedTensor; - auto slice_dim_size = source_slice.size(dim); + const auto& source_slice = notSkippedTensor; + auto slice_dim_size = source_slice.sizes()[dim]; auto result_slice = result.narrow(dim, 0, slice_dim_size); auto result_slice_data = result_slice.data_ptr(); auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); @@ -226,7 +232,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { if (should_skip(tensor)) { continue; } - auto slice_dim_size = tensor.size(dim); + auto slice_dim_size = tensor.sizes()[dim]; auto result_slice = result.narrow(dim, offset, slice_dim_size); auto iter = TensorIteratorConfig() diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index f732cb9a0141..d92864e6fb2a 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -77,7 +77,9 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional), random_full DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional), random_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub); -DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional), multinomial_stub); +DECLARE_DISPATCH( + void (*)(Tensor&, const Tensor&, int64_t, c10::optional), + multinomial_with_replacement_stub); DECLARE_DISPATCH( void (*)( TensorIterator&, diff --git a/aten/src/ATen/native/cpu/CatKernel.cpp b/aten/src/ATen/native/cpu/CatKernel.cpp index 299850407da3..f86adb8e6318 100644 --- a/aten/src/ATen/native/cpu/CatKernel.cpp +++ b/aten/src/ATen/native/cpu/CatKernel.cpp 
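
The immediately-invoked lambda introduced in _cat_out_cpu above is worth a small generic sketch (Item, caller, and the member name skip are invented): returning a pointer, with nullptr meaning "everything was skipped", removes the need for a default-constructed placeholder such as the old `Tensor notSkippedTensor;`.

#include <vector>

struct Item { bool skip; };

void caller(const std::vector<Item>& items) {
  // Immediately-invoked lambda, as in _cat_out_cpu: yields either a pointer to
  // the first non-skipped element or nullptr.
  const Item* first = [](const std::vector<Item>& items) -> const Item* {
    for (const auto& item : items) {
      if (item.skip) {
        continue;
      }
      return &item;
    }
    return nullptr;
  }(items);

  if (first == nullptr) {
    return;  // everything was skipped
  }
  const Item& notSkipped = *first;
  (void)notSkipped;
}
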
@@ -15,18 +15,20 @@ struct InputMeta { InputMeta(const Tensor& t, int64_t dim, int64_t inner) : data_ptr(t.data_ptr()) - , inner_size(t.size(dim) * inner) {} + , inner_size(t.sizes()[dim] * inner) {} }; template void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) { - int64_t outer = result.numel() / (result.size(dim) * result.stride(dim)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl"); + int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]); scalar_t* result_data = result.data_ptr(); int64_t ninputs = tensors.size(); std::vector inputs; inputs.reserve(ninputs); for (auto const &tensor : tensors) { - inputs.emplace_back(tensor, dim, result.stride(dim)); + inputs.emplace_back(tensor, dim, result.strides()[dim]); } using Vec = vec256::Vec256; diff --git a/aten/src/ATen/native/cpu/MultinomialKernel.cpp b/aten/src/ATen/native/cpu/MultinomialKernel.cpp index 1f4a52084962..62f1d7b879ac 100644 --- a/aten/src/ATen/native/cpu/MultinomialKernel.cpp +++ b/aten/src/ATen/native/cpu/MultinomialKernel.cpp @@ -11,8 +11,12 @@ namespace at { namespace native { namespace { -template -void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional generator) { +template +void multinomial_with_replacement_apply( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { auto gen = get_generator_or_default(generator, detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); @@ -61,8 +65,6 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl } TORCH_CHECK(sum > 0, "invalid multinomial distribution (sum of probabilities <= 0)"); - TORCH_CHECK(with_replacement || (n_categories - n_zeros >= n_sample), - "invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)"); /* normalize cumulative probability distribution so that last val is 1 i.e. doesn't assume original self row sums to one */ @@ -100,45 +102,23 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl /* store in result tensor (will be incremented for lua compat by wrapper) */ result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx; - - /* Once a sample is drawn, it cannot be drawn again. ie sample without replacement */ - if (!with_replacement && j < n_sample - 1) { - /* update cumulative distribution so that sample cannot be drawn again */ - scalar_t diff; - scalar_t new_val = 0; - scalar_t sum; - - if (sample_idx != 0) { - new_val = cum_dist_ptr[(sample_idx - 1) * cum_dist_stride_0]; - } - /* marginal cumulative mass (i.e. original probability) of sample */ - diff = cum_dist_ptr[sample_idx * cum_dist_stride_0] - new_val; - /* new sum of marginals is not one anymore... 
*/ - sum = 1.0 - diff; - for (int64_t k = 0; k < n_categories; k++) { - new_val = cum_dist_ptr[k * cum_dist_stride_0]; - if (k >= sample_idx) { - /* remove sampled probability mass from later cumulative probabilities */ - new_val -= diff; - } - /* make total marginals sum to one */ - new_val /= sum; - cum_dist_ptr[k * cum_dist_stride_0] = new_val; - } - } } } } -static void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional gen) { +static void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional gen) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "multinomial", [&] { - multinomial_apply(result, self, n_sample, with_replacement, gen); + multinomial_with_replacement_apply(result, self, n_sample, gen); }); } - } -REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); - +REGISTER_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); } } diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index 67adbaabbb84..c3e456d97056 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -57,6 +57,12 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, accscalar_t pinv = accscalar_t(1)/p; + // Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements + // in the vec=2 and vec=4 cases. + bool gridxvec_loop_state = 0; + + float4 rand; + // Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time for (IndexType linearIndex = idx * VEC; linearIndex < totalElements; @@ -69,12 +75,21 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything // Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4) // sets of rand. - float4 rand = curand_uniform4(&state); + if ((VEC == 4) || (gridxvec_loop_state == 0)) { + rand = curand_uniform4(&state); + } else { + // sets up the last two values we generated last iteration to be used this iteration. 
+ rand.x = rand.z; + rand.y = rand.w; + gridxvec_loop_state ^= 1; + } rand.x = rand.x < p; rand.y = rand.y < p; - rand.z = rand.z < p; - rand.w = rand.w < p; + if (VEC == 4) { + rand.z = rand.z < p; + rand.w = rand.w < p; + } // Note: We explicitly check for is_contiguous() before launching the vectorized kernel // and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other) diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 3d59617903b4..cc74848b632a 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -300,7 +300,11 @@ sampleMultinomialOnce(int64_t* dest, } } -void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional generator) { +void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { auto gen = get_generator_or_default(generator, cuda::detail::getDefaultCUDAGenerator()); int inputSize = self.dim(); @@ -371,7 +375,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n PhiloxCudaState rng_engine_inputs; - if (with_replacement) { // Binary search is warp divergent (so effectively we're running // with just a single thread), but for better utilization, // we need each block to have at least 4 warps. @@ -402,7 +405,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n prefixSum.data_ptr(), normDist.data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); - } } }); @@ -412,6 +414,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n } } -REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); - +REGISTER_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); }} diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 9bb679beb3d0..6c3298b72e75 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -650,7 +650,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen dstBuffer.add_(srcBuffer, value); } } else { - AT_DISPATCH_ALL_TYPES( + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, commonDtype, "add_dense_sparse", [&] { add_dense_sparse_worker_cpu(resultBuffer, value, sparse, indices, valuesBuffer); }); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index c8366f71618e..fce3446816e7 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -338,8 +338,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT if (sparse.dense_dim() == 0) { TORCH_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); - AT_DISPATCH_ALL_TYPES_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> <<>>( TensorCAddOp(value.to()), diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp 
b/aten/src/ATen/templates/RegisterDispatchKey.cpp index e923f6d73bd0..ed4359c6883e 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -37,10 +37,13 @@ namespace at { -namespace { - ${dispatch_definitions} +// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid +// ambiguity with conflicting identifiers that may have been defined in +// at namespace already. +namespace { + TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) { ${dispatch_registrations} } diff --git a/aten/src/ATen/test/ivalue_test.cpp b/aten/src/ATen/test/ivalue_test.cpp index 14e75205aa66..a0e2648758ff 100644 --- a/aten/src/ATen/test/ivalue_test.cpp +++ b/aten/src/ATen/test/ivalue_test.cpp @@ -51,6 +51,91 @@ TEST(IValueTest, Basic) { ASSERT_EQ(tv.use_count(), 2); } +static std::array makeSampleIValues() { + return { at::rand({3, 4}), "hello", 42, true, 1.5 }; +} + +static std::array makeMoreSampleIValues() { + return { at::rand({3, 4}), "goodbye", 23, false, 0.5 }; +} + +// IValue::operator== doesn't seem to work on Tensors. +#define EXPECT_IVALUE_EQ(a, b) \ + EXPECT_EQ((a).isTensor(), (b).isTensor()); \ + if ((a).isTensor()) { \ + EXPECT_TRUE(a.toTensor().equal(b.toTensor())); \ + } else { \ + EXPECT_EQ(a, b); \ + } + +TEST(IValueTest, Swap) { + // swap() has the following 3 cases: tensor, intrusive_ptr, or + // neither. Exercise all pairs of the three. + + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + for (const auto& input: sampleInputs) { + for (const auto& target: sampleTargets) { + IValue a(input); + IValue b(target); + EXPECT_IVALUE_EQ(a, input); + EXPECT_IVALUE_EQ(b, target); + a.swap(b); + EXPECT_IVALUE_EQ(a, target); + EXPECT_IVALUE_EQ(b, input); + } + } +} + +TEST(IValueTest, CopyConstruct) { + auto sampleInputs = makeSampleIValues(); + for (const IValue& v: sampleInputs) { + IValue copy(v); + EXPECT_IVALUE_EQ(copy, v); + } +} + +TEST(IValueTest, MoveConstruct) { + auto sampleInputs = makeSampleIValues(); + for (const IValue& v: sampleInputs) { + IValue source(v); + IValue target(std::move(source)); + EXPECT_IVALUE_EQ(target, v); + EXPECT_TRUE(source.isNone()); + } +} + +TEST(IValueTest, CopyAssign) { + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + + for (const IValue& input: sampleInputs) { + for (const IValue& target: sampleTargets) { + IValue copyTo(target); + IValue copyFrom(input); + copyTo = copyFrom; + EXPECT_IVALUE_EQ(copyTo, input); + EXPECT_IVALUE_EQ(copyFrom, input); + EXPECT_IVALUE_EQ(copyTo, copyFrom); + } + } +} + +TEST(IValueTest, MoveAssign) { + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + + for (const IValue& input: sampleInputs) { + for (const IValue& target: sampleTargets) { + IValue moveTo(target); + IValue moveFrom(input); + moveTo = std::move(moveFrom); + EXPECT_IVALUE_EQ(moveTo, input); + EXPECT_TRUE(moveFrom.isNone()); + } + } +} + TEST(IValueTest, Tuple) { std::tuple t = std::make_tuple(123, at::randn({1})); auto iv = IValue(t); @@ -318,5 +403,137 @@ TEST(IValueTest, EnumEquality) { ); } +TEST(IValueTest, isPtrType) { + IValue tensor(at::rand({3, 4})); + IValue undefinedTensor((at::Tensor())); + IValue integer(42); + IValue str("hello"); + + EXPECT_TRUE(tensor.isPtrType()); + EXPECT_FALSE(undefinedTensor.isPtrType()); + EXPECT_FALSE(integer.isPtrType()); + EXPECT_TRUE(str.isPtrType()); +} + +TEST(IValueTest, isAliasOf) { + auto sampleIValues = makeSampleIValues(); + for (auto& iv: 
sampleIValues) { + for (auto& iv2: sampleIValues) { + if (&iv == &iv2 && iv.isPtrType()) { + EXPECT_TRUE(iv.isAliasOf(iv2)); + } else { + EXPECT_FALSE(iv.isAliasOf(iv2)); + } + } + } +} + +TEST(IValueTest, internalToPointer) { + IValue tensor(at::rand({3, 4})); + IValue str("hello"); + + EXPECT_EQ(tensor.internalToPointer(), tensor.unsafeToTensorImpl()); + EXPECT_NE(str.internalToPointer(), nullptr); + + IValue nullStr((c10::intrusive_ptr())); + ASSERT_TRUE(nullStr.isString()); + EXPECT_EQ(nullStr.internalToPointer(), nullptr); +} + +TEST(IValueTest, IdentityComparisonAndHashing) { + at::Tensor t1 = at::rand({3, 4}); + at::Tensor t2 = at::rand({3, 4}); + IValue tv1(t1), tv2(t2); + IValue tv1b(t1); + + EXPECT_EQ(tv1.hash(), tv1b.hash()); + EXPECT_NE(tv1.hash(), tv2.hash()); + + EXPECT_TRUE(tv1.is(tv1)); + EXPECT_TRUE(tv1.is(tv1b)); + EXPECT_TRUE(tv1b.is(tv1)); + EXPECT_TRUE(tv2.is(tv2)); + + EXPECT_FALSE(tv1.is(tv2)); + EXPECT_FALSE(tv2.is(tv1)); + + IValue none; + IValue undefinedTensor((at::Tensor())); + + EXPECT_TRUE(none.is(undefinedTensor)); + EXPECT_TRUE(undefinedTensor.is(none)); + + // Is this a bug? We should probably have a is b => a.hash() == b.hash() + EXPECT_NE(none.hash(), undefinedTensor.hash()); + + auto sampleIValues = makeSampleIValues(); + auto sampleIValues2 = makeSampleIValues(); + auto moreSampleIValues = makeMoreSampleIValues(); + + ASSERT_EQ(sampleIValues.size(), moreSampleIValues.size()); + for (int ii = 0; ii < sampleIValues.size(); ++ii) { + // Constant strings will have the same pointer value. + if (sampleIValues[ii].isPtrType() && !sampleIValues[ii].isString()) { + EXPECT_NE(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + } else { + EXPECT_EQ(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + } + EXPECT_NE(sampleIValues[ii].hash(), moreSampleIValues[ii].hash()); + } +} + +TEST(IValueTest, getSubValues) { + // Scalars have no subvalues. + IValue integer(42), float_(1.5); + + IValue::HashAliasedIValues subvalues; + + integer.getSubValues(subvalues); + EXPECT_TRUE(subvalues.empty()); + + subvalues.clear(); + + float_.getSubValues(subvalues); + EXPECT_TRUE(subvalues.empty()); + + subvalues.clear(); + + at::Tensor t1(at::rand({3, 4})), t2(at::rand({3, 4})); + IValue tv1(t1), tv2(t2); + IValue list(std::vector{t1, t2}); + IValue tuple(ivalue::Tuple::create({tv1, tv2})); + + std::unordered_map m; + m[1] = t1; + m[2] = t2; + + IValue dict(std::move(m)); + + auto objType = ClassType::create(nullopt, {}); + objType->addAttribute("t1", tv1.type()); + objType->addAttribute("t2", tv2.type()); + + auto o = ivalue::Object::create(StrongTypePtr(nullptr, objType), 2); + o->setSlot(0, tv1); + o->setSlot(1, tv2); + + IValue object(o); + tv1.getSubValues(subvalues); + EXPECT_EQ(subvalues.size(), 1); + EXPECT_EQ(subvalues.count(tv1), 1); + + subvalues.clear(); + + for (auto& container: {list, tuple, dict, object}) { + container.getSubValues(subvalues); + EXPECT_EQ(subvalues.size(), 3); + EXPECT_EQ(subvalues.count(container), 1); + EXPECT_EQ(subvalues.count(tv1), 1); + EXPECT_EQ(subvalues.count(tv2), 1); + + subvalues.clear(); + } +} + // TODO(gmagogsfm): Add type conversion test? 
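
One small hardening note on the EXPECT_IVALUE_EQ helper added above: because it expands to several statements, it only composes safely inside braces (which is how these tests use it). If it is ever reused after an unbraced if or as a bare loop body, the conventional do/while(0) wrapping sketched below would keep it a single statement; this is a suggestion, not part of the patch.

#define EXPECT_IVALUE_EQ(a, b)                            \
  do {                                                    \
    EXPECT_EQ((a).isTensor(), (b).isTensor());            \
    if ((a).isTensor()) {                                 \
      EXPECT_TRUE((a).toTensor().equal((b).toTensor()));  \
    } else {                                              \
      EXPECT_EQ(a, b);                                    \
    }                                                     \
  } while (0)
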
} // namespace c10 diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 637db95991f2..790d97ee3994 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -206,7 +206,7 @@ class intrusive_ptr final { "NullType must have a constexpr singleton() method"); #endif static_assert( - std::is_same::value, + std::is_base_of::type>::value, "NullType::singleton() must return a element_type* pointer"); TTarget* target_; @@ -509,7 +509,7 @@ class weak_intrusive_ptr final { "NullType must have a constexpr singleton() method"); #endif static_assert( - std::is_same::value, + std::is_base_of::type>::value, "NullType::singleton() must return a element_type* pointer"); TTarget* target_; diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index d014bd31f02e..7965b3cc88a4 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -573,7 +573,16 @@ def forward(self, x): m = convert_fx(m) m(tensor_input) - def test_standalone_module(self): + def _test_standalone_module( + self, + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check): + """ Test standalone module with different quantized input/quantized output + configurations + """ class StandaloneModule(torch.nn.Module): def __init__(self): super().__init__() @@ -613,45 +622,32 @@ def forward(self, x): original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) - qconfig_dict = {"": default_qconfig} - config_name = {"standalone_module_name": [("standalone", None, None)]} - config_class = {"standalone_module_class": [(StandaloneModule, None, None)]} - for prepare_config in [config_name, config_class]: + for is_name in [True, False]: + if is_name: + prepare_config = { + "standalone_module_name": [("standalone", None, interface_config)] + } + else: + prepare_config = { + "standalone_module_class": [(StandaloneModule, None, interface_config)] + } + original_m_copy = copy.deepcopy(original_m) original_ref_m_copy = copy.deepcopy(original_ref_m) + + qconfig_dict = {"": default_qconfig} # check prepared model m = prepare_fx( original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) # calibration m(data) - # input and output of first conv, observer for standalone module - # will be inserted in the standalone module itself - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - # for input and output of conv in the standalone module - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) # check converted/quantized model m = convert_fx(m) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - count_check = { - # standalone module will take float as input and output - # so we'll see quantize and dequantize in the modoule - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d): 1, - 
ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) res = m(data) # quantize the reference model @@ -661,6 +657,76 @@ def forward(self, x): ref_res = ref_m(data) self.assertEqual(res, ref_res) + def test_standalone_module_float_interface(self): + float_interface_config = { + "input_quantized_idxs": [], # float input + "output_quantized_idxs": [], # float output + } + interface_config = float_interface_config + # input and output of first conv, observer for standalone module + # will be inserted in the standalone module itself + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for input and output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # standalone module will take float as input and output + # so we'll see quantize and dequantize in the modoule + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d): 1, + ns.call_method("dequantize") : 1, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + + def test_standalone_module_quantized_interface(self): + quantized_interface_config = { + "input_quantized_idxs": [0], # quantized input + "output_quantized_idxs": [0], # quantized output + } + interface_config = quantized_interface_config + # observer for input and output of first conv + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 1 + } + convert_count_check = { + # quantizing input for conv + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + # dequantizing output of standalone module + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # quantization of input happens in parent module + # quantization of output happens in the quantized conv module + ns.call_function(torch.quantize_per_tensor) : 0, + ns.call_module(nnq.Conv2d): 1, + # dequantization for output happens in parent module + ns.call_method("dequantize") : 0, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + @skipIfNoFBGEMM def test_qconfig_none(self): class M(torch.nn.Module): diff --git a/test/test_dataset.py b/test/test_dataset.py index 2caa1a248435..a72b87cca555 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -90,7 +90,7 @@ def _collate_fn(batch): y = next(ds_iter) self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float)) - collate_ds_nolen = CollateIterableDataset(ds_nolen) + collate_ds_nolen = CollateIterableDataset(ds_nolen) # type: ignore with self.assertRaises(NotImplementedError): len(collate_ds_nolen) ds_nolen_iter = iter(ds_nolen) @@ -144,7 +144,7 @@ def test_sampler_dataset(self): arrs = range(10) ds = IterDatasetWithLen(arrs) # Default SequentialSampler - 
sampled_ds = SamplerIterableDataset(ds) + sampled_ds = SamplerIterableDataset(ds) # type: ignore self.assertEqual(len(sampled_ds), 10) i = 0 for x in sampled_ds: @@ -152,7 +152,7 @@ def test_sampler_dataset(self): i += 1 # RandomSampler - random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True) + random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True) # type: ignore # Requires `__len__` to build SamplerDataset ds_nolen = IterDatasetWithoutLen(arrs) diff --git a/test/test_sparse.py b/test/test_sparse.py index 4e982b8333d9..228c66aa403e 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -356,6 +356,11 @@ def test_to_sparse(self): sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3]) self.assertRaises(RuntimeError, lambda: sp.to_sparse()) + def test_sparse_bool(self): + a = self.value_tensor([True, False]).to(torch.bool) + b = a.to_sparse().to_dense() + self.assertEqual(a, b) + def test_scalar(self): # tensor with value a = self.sparse_tensor(self.index_tensor([]).unsqueeze(1), 12.3, []) diff --git a/test/test_torch.py b/test/test_torch.py index 1f85ed2fff54..72fa853e2e7c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5689,7 +5689,8 @@ def test_storage_multigpu(self, devices): x = torch.tensor([], device=device) self.assertEqual(x.dtype, x.storage().dtype) - @dtypes(torch.float, torch.double, torch.half) + @dtypesIfCUDA(torch.float, torch.double, torch.half) + @dtypes(torch.float, torch.double) def test_multinomial(self, device, dtype): def make_prob_dist(shape, is_contiguous): if is_contiguous: diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 026293a9281a..9d4fa54c93b3 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -131,6 +131,20 @@ def is_hip_clang(): sources.write(line) print("%s updated" % gloo_cmake_file) +gloo_cmake_file = "third_party/gloo/cmake/Modules/Findrccl.cmake" +if os.path.exists(gloo_cmake_file): + do_write = False + with open(gloo_cmake_file, "r") as sources: + lines = sources.readlines() + newlines = [line.replace('RCCL_LIBRARY', 'RCCL_LIBRARY_PATH') for line in lines] + if lines == newlines: + print("%s skipped" % gloo_cmake_file) + else: + with open(gloo_cmake_file, "w") as sources: + for line in newlines: + sources.write(line) + print("%s updated" % gloo_cmake_file) + hipify_python.hipify( project_directory=proj_dir, output_directory=out_dir, diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 8f521e6651bc..4768670b6f26 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -435,6 +435,8 @@ def gen_one(f: NativeFunction) -> Optional[str]: # For an overview of what this template code looks like, see # https://github.com/pytorch/rfcs/pull/9 return f"""\ +namespace {{ + {self.gen_structured_class( f, k, class_name=class_name, @@ -448,6 +450,8 @@ def gen_one(f: NativeFunction) -> Optional[str]: {impl_call} return {ret_expr}; }} + +}} // anonymous namespace """ elif self.target is Target.REGISTRATION: @@ -540,9 +544,13 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]: """ return f"""\ +namespace {{ + {returns_type} {name}({args_str}) {{ {cuda_guard}{return_kw}{impl_name}({args_exprs_str}); }} + +}} // anonymous namespace """ elif self.target is Target.REGISTRATION: diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index f1a0a634727a..5bddc510fe56 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ 
b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -209,7 +209,7 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( std::stringstream encoded_inputs; for (const auto& input : inputs) { if (input.isTensor()) { - auto input_tensor = input.toTensor(); + auto& input_tensor = input.toTensor(); encoded_inputs << ";"; auto sep = ""; diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index 4e76dc23e55d..4f4aa0d1536b 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -45,11 +45,17 @@ constexpr int so_suffix_len = 3; constexpr int cpp_suffix_len = 4; #endif +intptr_t run(const std::string& cmd); + static bool programExists(const std::string& program) { TemplateEnv env; env.s("program", program); std::string cmd = format(check_exists_string, env); +#ifdef _MSC_VER + return (run(cmd.c_str()) == 0); +#else return (system(cmd.c_str()) == 0); +#endif } #ifdef _MSC_VER diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 1bab391bd393..0c88371399de 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -137,7 +137,7 @@ Value* TracingState::getValue(const IValue& var) { return graph->insertNode(dict_node)->output(); } if (var.isTensor()) { - auto ten = var.toTensor(); + auto& ten = var.toTensor(); if (!ten.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); @@ -237,7 +237,7 @@ bool TracingState::hasValue(const IValue& var) const { Value* TracingState::getOutput(const IValue& iv, size_t i) { bool tracing_mode_strict = getTracingState()->strict; if (iv.isTensor()) { - at::Tensor var = iv.toTensor(); + const at::Tensor& var = iv.toTensor(); if (!var.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); @@ -506,7 +506,7 @@ void setValueTrace(const IValue& v, Value* value) { } void TracingState::setValue(const IValue& v, Value* value) { if (v.isTensor()) { - auto var = v.toTensor(); + auto& var = v.toTensor(); AT_ASSERT(var.defined()); env_stack.back()[v] = value; } else if (v.isTensorList()) { diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 2778c7712f23..f66f54eeb567 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -289,11 +289,11 @@ class AttributePropagator { IValue overrideGradient(IValue attr) { if (attr.isTensor()) { - auto t = attr.toTensor(); + auto& t = attr.toTensor(); if (t.requires_grad()) { - t = t.detach(); - t.set_requires_grad(false); - attr = IValue(t); + auto detached = t.detach(); + detached.set_requires_grad(false); + attr = IValue(std::move(detached)); } } else if (attr.isTuple()) { auto tuple = std::move(attr).toTuple(); diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 401933c6d67e..a0e60e879146 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -237,7 +237,7 @@ struct CompleteArgumentSpec { for (int32_t i = 0; i < num_inputs; i++) { if (!inputs[i].isTensor()) continue; - auto tensor = inputs[i].toTensor(); + auto& tensor = inputs[i].toTensor(); all_dims += tensor.defined() ? 
tensor.ndimension() : 0; } // allocate enough room for all TensorPODs and dimensions diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 24ca9dbf9793..ce4718becaf7 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -1418,7 +1418,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // Check every input's shape against profiled (expected) shape. for (i = 0; i < num_inputs; i++) { auto& input = peek(stack, i, num_inputs); - auto t = input.toTensor(); + auto& t = input.toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X + i]; auto expected_type = expected->cast(); if (t.defined() && !expected_type->matchTensor(t)) { @@ -1439,7 +1439,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // so it's safe to pass this guard check push(stack, true); } else { - auto t = stack.back().toTensor(); + auto& t = stack.back().toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X]; auto expected_type = expected->cast(); if (t.defined() && diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp index 8d276dd58b50..d233f089f187 100644 --- a/torch/csrc/jit/runtime/profiling_record.cpp +++ b/torch/csrc/jit/runtime/profiling_record.cpp @@ -165,7 +165,7 @@ void ProfilingRecord::insertShapeProfile(Node* n, size_t offset) { if (v.isTensor()) { std::lock_guard lock(this->mutex_); auto& profiled_types = profiled_types_per_frame_[frame_id]; - auto t = v.toTensor(); + auto& t = v.toTensor(); if (t.defined()) { auto pttp = tensorTypeInCurrentExecutionContext(t); GRAPH_DEBUG( diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 5c118f513565..89519d3765b5 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -79,13 +79,13 @@ struct static_add final : public at::native::structured_add_out { REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); auto in2_s = p_node->Input(2, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); static_add op{out_t}; op.meta(in0_t, in1_t, in2_s); op.impl(in0_t, in1_t, in2_s, out_t); @@ -94,12 +94,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::mul_out(out_t, in0_t, in1_t); }; @@ -107,15 +107,15 @@ REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) 
{ - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); - auto in2_t = p_node->Input(2, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); + auto& in2_t = p_node->Input(2, reg).toTensor(); auto in3_s = p_node->Input(3, reg).toScalar(); auto in4_s = p_node->Input(4, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); }; @@ -123,13 +123,13 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_s = p_node->Input(1, reg).toScalar(); auto in2_s = p_node->Input(2, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::clamp_out(out_t, in0_t, in1_s, in2_s); }; @@ -137,12 +137,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::bmm_out_cpu(out_t, in0_t, in1_t); }; @@ -154,7 +154,7 @@ REGISTER_OPERATOR_FUNCTOR( [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { auto input_size = p_node->input_regs().size(); - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); double in1_d = input_size > 1 ? p_node->Input(1, reg).toDouble() : 0; double in2_d = input_size > 2 ? 
p_node->Input(2, reg).toDouble() : std::numeric_limits::infinity(); @@ -164,7 +164,7 @@ REGISTER_OPERATOR_FUNCTOR( if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::nan_to_num_out(out_t, in0_t, in1_d, in2_d, in3_d); }; @@ -176,18 +176,18 @@ REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator { if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_tl[0]); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::_cat_out_cpu(out_t, in0_tl, in1_i); }; }); REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::tanh_out(out_t, in0_t); }; @@ -217,7 +217,7 @@ SROperator aten_stack(Node* n) { for (auto i = 0; i < inputs.size(); i++) { inputs[i] = inputs[i].unsqueeze(dim); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::_cat_out_cpu(out_t, inputs, dim); }; @@ -230,11 +230,11 @@ REGISTER_OPERATOR_FUNCTOR( aten_sigmoid, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::sigmoid_out(out_t, in0_t); }; @@ -247,57 +247,57 @@ REGISTER_OPERATOR_FUNCTOR( if (in1) { auto in1_s = in1->toScalar(); return [=](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); at::native::leaky_relu_out(out_t, in0_t, in1_s); }; } else { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_s = p_node->Input(1, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); at::native::leaky_relu_out(out_t, in0_t, in1_s); }; } }); REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::threshold_out(out_t, in0_t, 0, 0); }; }); 
REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); double in1_d = p_node->input_regs().size() > 1 ? p_node->Input(1, reg).toDouble() : -1.0; if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::logit_out(out_t, in0_t, in1_d); }; }); REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); at::native::resize_as_(out_t, in0_t, c10::nullopt); at::native::copy_(out_t, in0_t, false); }; @@ -317,14 +317,14 @@ std::function&)> getNativeOperation(Node* n) { if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toInt(); auto in2_i = p_node->Input(2, reg).toInt(); p_node->Output(0, reg) = at::native::transpose(in0_t, in1_i, in2_i); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toInt(); auto in2_i = p_node->Input(2, reg).toInt(); p_node->Output(0, reg) = at::native::flatten(in0_t, in1_i, in2_i); @@ -386,19 +386,19 @@ getNativeOperation(Node* n) { }; } else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_iv = p_node->Input(1, reg).toIntVector(); p_node->Output(0, reg) = at::native::permute(in0_t, in1_iv); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_iv = p_node->Input(1, reg).toIntVector(); p_node->Output(0, reg) = at::native::reshape(in0_t, in1_iv); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toInt(); auto in2_i = p_node->Input(2, reg).toInt(); auto in3_i = p_node->Input(3, reg).toInt(); @@ -408,13 +408,13 @@ getNativeOperation(Node* n) { }; } else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto self = p_node->Input(0, reg).toTensor(); // self + auto& self = p_node->Input(0, reg).toTensor(); // self auto dim = p_node->Input(1, reg).toInt(); // dim int64_t start = 0; if (p_node->Input(2, reg).isScalar()) { start = p_node->Input(2, reg).toInt(); } else { - auto t = p_node->Input(2, 
reg).toTensor(); + auto& t = p_node->Input(2, reg).toTensor(); start = t.item(); } auto length = p_node->Input(3, reg).toInt(); // length @@ -440,7 +440,7 @@ getNativeOperation(Node* n) { } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) { return [](const ProcessedNode* p_node, std::vector& reg) { DCHECK(p_node->input_regs().size() == 5); - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toScalarType(); auto in2_i = p_node->Input(2, reg).toBool(); auto in3_i = p_node->Input(3, reg).toBool(); diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 6e5c3b927c38..811569485888 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -354,7 +354,7 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { // // The format here is the same one used by `torch.save()`. The code for the // format can be found in `torch/serialization.py`. - auto tensor = ivalue.toTensor(); + auto& tensor = ivalue.toTensor(); bool quantized = tensor.is_quantized(); // The arguments to this function are: // storage, storage_offset, size, stride, requires_grad, backward_hooks diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index c86cbc460c9c..18d656c98f32 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -309,12 +309,12 @@ struct PythonPrintImpl { // because it doesn't hash any information about the tensors. // We will probably need to optimize this at some point using hashing. if (val.isTensor()) { - auto t = val.toTensor(); + auto& t = val.toTensor(); for (size_t i = 0; i < constant_table_.size(); ++i) { if (!constant_table_[i].isTensor()) { continue; } - auto t2 = constant_table_[i].toTensor(); + auto& t2 = constant_table_[i].toTensor(); if (t.options().type_equal(t2.options()) && t.equal(t2)) { return i; } diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 3ff5da29fe1f..841e87592be9 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -632,7 +632,7 @@ void Unpickler::rebuildTensor(bool quantized) { auto tup = pop(stack_).toTuple(); const auto& elements = tup->elements(); size_t idx = 0; - auto storage_tensor = elements.at(idx++).toTensor(); + auto& storage_tensor = elements.at(idx++).toTensor(); int64_t storage_offset = elements.at(idx++).toInt(); std::vector size = tupleToIntList(elements.at(idx++)); std::vector stride = tupleToIntList(elements.at(idx++)); diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 073c95c28619..837ecca6fe9d 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -922,9 +922,8 @@ def __setstate__(self, state): super(MultiheadAttention, self).__setstate__(state) - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. 
diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index f22c35fa39ff..6a9c4dcd2ef6 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -530,8 +530,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, # dilation being an optional parameter is for backwards # compatibility - def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None): - # type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int] + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + dilation: Optional[List[int]] = None) -> List[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already else: diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 297a4edf15bf..f054590da66a 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -843,7 +843,6 @@ def _slow_forward(self, *input, **kwargs): if recording_scopes: name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None if name: - cur_scope_name = tracing_state.current_scope() tracing_state.push_scope(name) else: recording_scopes = False diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 3e0b93c7afc0..97e4195619cb 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -26,8 +26,7 @@ def _reverse_repeat_tuple(t, n): return tuple(x for x in reversed(t) for _ in range(n)) -def _list_with_default(out_size, defaults): - # type: (List[int], List[int]) -> List[int] +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size): diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index a069c6c6f939..8effeece5908 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -108,7 +108,6 @@ def replicate(network, devices, detach=False): modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} - scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"} for i, module in enumerate(modules): module_indices[module] = i diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py index df88169471ca..59c0195d7858 100644 --- a/torch/nn/quantized/dynamic/modules/rnn.py +++ b/torch/nn/quantized/dynamic/modules/rnn.py @@ -239,8 +239,6 @@ def from_float(cls, mod): _all_weight_values = [] for layer in range(qRNNBase.num_layers): for direction in range(num_directions): - layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions - suffix = '_reverse' if direction == 1 else '' def retrieve_weight_bias(ihhh): diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index d16748b3baf7..e41d55347741 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ b/torch/nn/quantized/modules/embedding_ops.py @@ -52,7 +52,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - version = local_metadata.get('version', None) self.dtype = state_dict[prefix + 'dtype'] state_dict.pop(prefix + 'dtype') diff --git a/torch/nn/quantized/modules/normalization.py 
b/torch/nn/quantized/modules/normalization.py index 4664120ec8b5..c12f74374863 100644 --- a/torch/nn/quantized/modules/normalization.py +++ b/torch/nn/quantized/modules/normalization.py @@ -29,7 +29,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.normalized_shape, mod.weight, mod.bias, float(scale), @@ -63,7 +62,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point), @@ -98,7 +96,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -133,7 +130,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -168,7 +164,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 84fa30021ed1..851a551da0d8 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -587,7 +587,6 @@ def compute_mask(self, t, default_mask): # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) - nparams_tokeep = tensor_size - nparams_toprune # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 5aabbd66c4b1..59e3851dcd57 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -21,7 +21,7 @@ from .quantization_types import Pattern -from typing import Callable, Tuple, Optional +from typing import Callable, Tuple class Fuser: @@ -59,11 +59,12 @@ def load_arg(a): model = GraphModule(input_root, self.fused_graph) return model - def _find_matches(self, root: GraphModule, graph: Graph, - patterns: Dict[Pattern, Callable] - ) -> Dict[str, Tuple[Node, Optional[Any]]]: + def _find_matches( + self, root: GraphModule, graph: Graph, + patterns: Dict[Pattern, Callable] + ) -> Dict[str, Tuple[Node, FuseHandler]]: modules = dict(root.named_modules()) - match_map = {} # node name -> (root_node, match_value?) 
+ match_map : Dict[str, Tuple[Node, FuseHandler]] = {} # node name -> (root_node, match_value) def apply_match(pattern, node, match): if isinstance(pattern, tuple): diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index b7af6008b3f3..1749484fccec 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,12 +6,25 @@ from .utils import _parent_name from .quantization_types import QuantizerCls from ..fuser_method_mappings import get_fuser_method +from abc import ABC, abstractmethod from typing import Any, Callable, Dict # --------------------- -# Fusion Patterns +# Fusion Pattern Registrations # --------------------- +# Base Pattern Handler +class FuseHandler(ABC): + """ Base handler class for the fusion patterns + """ + def __init__(self, quantizer: QuantizerCls, node: Node): + pass + + @abstractmethod + def fuse(self, quantizer: QuantizerCls, load_arg: Callable, + fuse_custom_config_dict: Dict[str, Any] = None) -> Node: + pass + @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d)) @@ -27,9 +40,9 @@ @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) -class ConvBNReLUFusion(): +class ConvBNReLUFusion(FuseHandler): def __init__(self, quantizer: QuantizerCls, node: Node): - super().__init__() + super().__init__(quantizer, node) self.relu_node = None self.bn_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ @@ -94,9 +107,9 @@ def fuse(self, quantizer: QuantizerCls, load_arg: Callable, @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d)) -class ModuleReLUFusion(): +class ModuleReLUFusion(FuseHandler): def __init__(self, quantizer: QuantizerCls, node: Node): - super().__init__() + super().__init__(quantizer, node) self.relu_node = node assert isinstance(node.args[0], Node) node = node.args[0] diff --git a/torch/quantization/fx/observed_module.py b/torch/quantization/fx/observed_module.py index a95bc184fa10..808a3b36fb4a 100644 --- a/torch/quantization/fx/observed_module.py +++ b/torch/quantization/fx/observed_module.py @@ -2,11 +2,11 @@ import copy from torch.fx import GraphModule # type: ignore from torch.fx.graph import Graph -from typing import Union, Dict, Any +from typing import Union, Dict, Any, List class ObservedGraphModule(GraphModule): - def get_preserved_attr_names(self): + def get_preserved_attr_names(self) -> List[str]: return ['_activation_post_process_map', '_patterns', '_qconfig_map', @@ -35,6 +35,12 @@ def is_observed_module(module: Any) -> bool: return isinstance(module, ObservedGraphModule) class ObservedStandaloneGraphModule(ObservedGraphModule): + def get_preserved_attr_names(self) -> List[str] : + return super().get_preserved_attr_names() + [ + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs" + ] + def __deepcopy__(self, memo): fake_mod = torch.nn.Module() fake_mod.__dict__ = copy.deepcopy(self.__dict__) diff --git a/torch/quantization/fx/quantization_patterns.py 
b/torch/quantization/fx/quantization_patterns.py index 46fbed74bdc8..fb5bef0bd0ad 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -755,10 +755,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, qconfig = quantizer.qconfig_map[node.name] convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore observed_standalone_module = quantizer.modules[node.target] + input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist() quantized_standalone_module = convert(observed_standalone_module, debug=debug) parent_name, name = _parent_name(node.target) # update the modules dict setattr(quantizer.modules[parent_name], name, quantized_standalone_module) quantizer.modules[node.target] = quantized_standalone_module - # standalone module takes float input - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs)) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index af9496a66a63..318295270b61 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -102,14 +102,15 @@ def insert_observer( 'call_module', observer_name, (load_arg(node),), {}) observed_node_names_set.add(node.name) -def insert_observer_for_special_module( +def maybe_insert_observer_for_special_module( quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module], - prepare_custom_config_dict: Any, qconfig: Any, node: Node): + prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]: """ Insert observer for custom module and standalone module Returns: standalone_module_input_idxs: the indexs for inputs that needs to be observed by parent module """ assert modules is not None + standalone_module_input_idxs = None if isinstance(quantize_handler, CustomModuleQuantizeHandler): custom_module = modules[node.target] # type: ignore custom_module_class_mapping = prepare_custom_config_dict.get( @@ -129,19 +130,22 @@ def insert_observer_for_special_module( class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs} name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs} config = class_config_map.get(type(standalone_module), (None, None)) - config = name_config_map.get(node.target, (None, None)) - standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0] - standalone_prepare_config_dict = {} if config[1] is None else config[1] + config = name_config_map.get(node.target, config) + sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0] + sm_prepare_config_dict = {} if config[1] is None else config[1] prepare = \ torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore observed_standalone_module = \ - prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict) + prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict) + standalone_module_input_idxs = observed_standalone_module.\ + _standalone_module_input_quantized_idxs.int().tolist() observed_standalone_module = mark_observed_standalone_module( observed_standalone_module) parent_name, name = _parent_name(node.target) setattr(modules[parent_name], name, observed_standalone_module) modules[node.target] = observed_standalone_module # type: ignore + return standalone_module_input_idxs def 
insert_observer_for_output_of_the_node( node: Node, @@ -155,7 +159,8 @@ def insert_observer_for_output_of_the_node( observed_graph: Graph, load_arg: Callable, observed_node_names_set: Set[str], - matched_nodes: Optional[List[Node]]): + matched_nodes: Optional[List[Node]], + standalone_module_input_idxs: Optional[List[int]]): """ Insert observer/fake_quantize module for output of the observed module if needed """ @@ -215,8 +220,13 @@ def input_is_observed(arg): observed_node_names_set.add(node.name) elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler): - # output is observed in the standalone module - return + assert node.op == "call_module" + assert isinstance(node.target, str) + sm_out_qidxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore + output_is_quantized = 0 in sm_out_qidxs + + if output_is_quantized: + observed_node_names_set.add(node.name) elif (quantize_handler.all_node_args and input_output_observed(quantize_handler)): # observer for outputs @@ -226,6 +236,16 @@ def input_is_observed(arg): activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set) + # insert observer for input of standalone module + if standalone_module_input_idxs is not None: + for idx in standalone_module_input_idxs: + if node.args[idx].name not in observed_node_names_set: # type: ignore + new_observer = qconfig.activation() + insert_observer( + node, new_observer, model, + activation_post_process_map, env, observed_graph, + load_arg, observed_node_names_set) + def insert_observer_for_input_arg_of_observed_node( node: Node, observed_node_names_set: Set[str], quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]], @@ -373,10 +393,19 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. 
- When we are preparing a standalone module: - both input and output are observed in prepared standalone module + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module Returns: model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph inputs that are expected to be quantized, + same as the input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexes for the graph outputs that are quantized, + same as the output_quantized_idxs configuration provided + for the standalone module """ if prepare_custom_config_dict is None: prepare_custom_config_dict = {} @@ -430,8 +459,6 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, def load_arg(a): return map_arg(a, lambda node: env[node.name]) - # indexes for the inputs that needs to be observed - standalone_module_observed_input_idxs: List[int] = [] graph_inputs = [] for node in model.graph.nodes: if node.op == 'placeholder': @@ -487,14 +514,15 @@ def load_arg(a): # parent if qconfig is not None: assert obj is not None - insert_observer_for_special_module( - obj, self.modules, prepare_custom_config_dict, qconfig, - node) + standalone_module_input_idxs = \ + maybe_insert_observer_for_special_module( + obj, self.modules, prepare_custom_config_dict, qconfig, + node) insert_observer_for_output_of_the_node( node, obj, qconfig, self.modules, model, pattern, self.activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set, - matched_nodes) + matched_nodes, standalone_module_input_idxs) else: env[node.name] = observed_graph.node_copy(node, load_arg) @@ -516,6 +544,21 @@ def load_arg(a): model = GraphModule(model, observed_graph) self.save_state(model) model = mark_observed_module(model) + if is_standalone_module: + assert result_node is not None + assert isinstance(result_node.args[0], Node), \ + "standalone module only supports returning a simple value currently "\ + "(not tuple, dict etc.)" + # indicator for whether the output is observed or not. + # This is used to quantize standalone modules correctly + output_is_observed = \ + result_node.args[0].name in observed_node_names_set + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + model._standalone_module_input_quantized_idxs = \ + torch.Tensor(input_quantized_idxs) + model._standalone_module_output_quantized_idxs = torch.Tensor(output_quantized_idxs) return model def save_state(self, observed: GraphModule) -> None: @@ -569,8 +612,10 @@ def _convert(self, model: GraphModule, debug: bool = False, """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - Returns a quantized standalone module which accepts float input - and produces float output.
+ Returns a quantized standalone module; whether the input/output is quantized is + specified by the input_quantized_idxs and output_quantized_idxs entries of + prepare_custom_config_dict. + See the docs for prepare_fx for details. """ if convert_custom_config_dict is None: convert_custom_config_dict = {} @@ -627,36 +672,50 @@ def load_x(n: Node) -> Node: else: return env[n.name] - def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]] + def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, list, boolean or tuple - - if quantized is a list or tuple, then arg should be a list and - the args with corresponding indexes will be quantized - - if quantized is a boolean, then all args will be - quantized/not quantized - if quantized is None, then we'll load the node as long as it exists + - if quantized is a boolean, then all args will be + quantized/not quantized + - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False) + - if quantized is a list or tuple, then arg should be a list and + the args with corresponding indexes will be quantized Output: fn which takes arg_or_args, and loads them from the corresponding environment depending on the value of quantized. """ assert quantized is None or \ isinstance(quantized, (tuple, list, bool)), type(quantized) + if isinstance(quantized, (tuple, list)) and len(quantized) == 0: + # empty tuple or list means nothing is quantized + quantized = False def load_arg_impl(arg_or_args): - if quantized is None: + # we'll update the format of `quantized` + # to better match arg_or_args + updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized + + if isinstance(quantized, (tuple, list)) and \ + len(quantized) == 1 and isinstance(arg_or_args, Node): + # when the argument is a single Node instead of a tuple, we just need to check + # whether 0 is in the quantized list + updated_quantized = 0 in quantized + + if updated_quantized is None: return map_arg(arg_or_args, load_x) - if isinstance(quantized, bool): + if isinstance(updated_quantized, bool): return map_arg( arg_or_args, - load_quantized if quantized else load_non_quantized) - elif isinstance(quantized, (tuple, list)): + load_quantized if updated_quantized else load_non_quantized) + elif isinstance(updated_quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): - if i in quantized: + if i in updated_quantized: loaded_args.append(map_arg(a, load_quantized)) else: loaded_args.append(map_arg(a, load_non_quantized)) @@ -690,10 +749,10 @@ def node_arg_is_quantized(node_arg: Any) -> bool: def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool: """ Check if output node is quantized or not """ assert self.modules is not None - # by default the output is expected to be quantized + # by default the output for a quantizable node is expected to be quantized quantized = True - # Need to get correct quantized/non-quantized state for the output + # Need to get correct quantized/non-quantized state for the output of CopyNode if type(obj) in [ CopyNode, @@ -750,7 +809,7 @@ def insert_quantize_node(node: Node) -> None: "output_quantized_idxs", []) for node in model.graph.nodes: - if node.op == 'output': + if node.op == "output": cur_output_node_idx = output_node_seen_cnt output_node_seen_cnt += 1 if cur_output_node_idx in
output_quantized_idxs: @@ -775,12 +834,19 @@ def insert_quantize_node(node: Node) -> None: quantized = False else: assert obj is not None + # We will get whether the output is quantized or not before + # convert for standalone module and after convert + # for non-standalone module, since _standalone_module_output_quantized_idxs + # is only available in observed standalone module + if is_observed_standalone_module_node: + out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore + assert len(out_quant_idxs) <= 1, "Currently standalone only support one output" + quantized = 0 in out_quant_idxs + result = obj.convert( self, node, load_arg, debug=debug, convert_custom_config_dict=convert_custom_config_dict) - if is_observed_standalone_module_node: - quantized = False - else: + if not is_observed_standalone_module_node: quantized = is_output_quantized(node, obj) if quantized: @@ -929,7 +995,7 @@ def _find_matches( standalone_module_names = [] match_map: Dict[str, MatchResult] = {} - all_matched = set() + all_matched : Set[str] = set() def record_match(pattern, node, matched): if isinstance(pattern, tuple): diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index c1f849803342..8285e204b1ed 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -9,7 +9,7 @@ Node, ) -from typing import Callable, Optional, List, Dict, Any +from typing import Callable, Optional, List, Dict, Any, Set # turn foo.bar -> ['foo', 'bar'] def _parent_name(target): @@ -140,7 +140,7 @@ def get_next_qparams_idx(module, qparams): inputs.append(graph.create_node('get_attr', qparam_full_path)) return graph.create_node('call_function', quantize_op, tuple(inputs), {}) -def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key): +def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key) -> List[Any]: r""" Get all the unique custom module keys in the custom config dict e.g. Input: @@ -163,7 +163,7 @@ def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key): [CustomModule1, CustomModule2, CustomModule3] """ # using set to dedup - float_custom_module_classes = set() + float_custom_module_classes : Set[Any] = set() custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) for quant_mode in ["static", "dynamic", "weight_only"]: quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index cba104b8f783..89ba877ffe78 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -107,8 +107,20 @@ def _prepare_standalone_module_fx( standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - Both input and output of the module are observed in the - standalone module. 
+ How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph inputs that are expected to be quantized, + same as the input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexes for the graph outputs that are quantized, + same as the output_quantized_idxs configuration provided + for the standalone module """ return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) @@ -378,8 +390,9 @@ def _convert_standalone_module_fx( r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model - Return: - A quantized standalone module which accepts float input - and produces float output. + Returns a quantized standalone module; whether the input/output is quantized is + specified by the input_quantized_idxs and output_quantized_idxs entries of + prepare_custom_config_dict. + See the docs for prepare_fx for details. """ return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)
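To make the new standalone-module interface concrete, here is a minimal end-to-end sketch. It relies only on constructs that appear in this patch and its tests (prepare_fx/convert_fx, default_qconfig, and the ("standalone", None, interface_config) entry format); the module definitions and the input shape are illustrative placeholders, not part of the patch.

    import torch
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class Standalone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv(x)

    class Parent(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.standalone = Standalone()

        def forward(self, x):
            return self.standalone(self.conv(x))

    # Quantized interface for the standalone child: it receives an
    # already-quantized input and returns a quantized output.
    interface_config = {
        "input_quantized_idxs": [0],
        "output_quantized_idxs": [0],
    }
    prepare_custom_config_dict = {
        # "standalone_module_class" with (Standalone, None, interface_config)
        # works the same way, keyed by class instead of by attribute name.
        "standalone_module_name": [("standalone", None, interface_config)],
    }

    model = Parent().eval()
    prepared = prepare_fx(model, {"": default_qconfig},
                          prepare_custom_config_dict=prepare_custom_config_dict)
    prepared(torch.randn(1, 1, 4, 4))   # calibration pass
    quantized = convert_fx(prepared)

After prepare_fx, the observed child carries the _standalone_module_input_quantized_idxs and _standalone_module_output_quantized_idxs attributes documented above, and convert_fx consults them to decide where quantize/dequantize nodes are placed around the standalone call.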