From d9c3485146913324ab4b3e211d2a4517e138f4af Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 20 May 2024 07:40:41 +0000 Subject: [PATCH 01/35] Revert "c10d: add Collectives abstraction (#125978)" This reverts commit 4b2ae2ac338f3a0de340c9711b03989b8ce66fc6. Reverted https://github.com/pytorch/pytorch/pull/125978 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/125978#issuecomment-2119858015)) --- BUILD.bazel | 2 +- build_variables.bzl | 1 - test/distributed/test_control_collectives.py | 189 ----------- torch/_C/_distributed_c10d.pyi | 14 - torch/csrc/distributed/c10d/HashStore.hpp | 2 +- torch/csrc/distributed/c10d/Store.hpp | 29 -- .../ControlCollectives.hpp | 59 ---- .../control_collectives/StoreCollectives.cpp | 222 ------------- .../control_collectives/StoreCollectives.hpp | 68 ---- torch/csrc/distributed/c10d/init.cpp | 302 +++--------------- torch/distributed/__init__.py | 2 - 11 files changed, 53 insertions(+), 837 deletions(-) delete mode 100644 test/distributed/test_control_collectives.py delete mode 100644 torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp delete mode 100644 torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp delete mode 100644 torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp diff --git a/BUILD.bazel b/BUILD.bazel index 831d64b44c2f6..3f7e6327452c0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -772,7 +772,7 @@ cc_library( [ "torch/*.h", "torch/csrc/**/*.h", - "torch/csrc/distributed/c10d/**/*.hpp", + "torch/csrc/distributed/c10d/*.hpp", "torch/lib/libshm/*.h", ], exclude = [ diff --git a/build_variables.bzl b/build_variables.bzl index 152324a4d90cb..3f16f9b847c1c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -487,7 +487,6 @@ libtorch_core_sources = sorted( # These files are the only ones that are supported on Windows. 
libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/Backend.cpp", - "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", "torch/csrc/distributed/c10d/Functional.cpp", "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp", diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py deleted file mode 100644 index fb0067f2dd2e9..0000000000000 --- a/test/distributed/test_control_collectives.py +++ /dev/null @@ -1,189 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -from datetime import timedelta -from multiprocessing.pool import ThreadPool - -import torch -import torch.distributed as dist -from torch.testing._internal.common_utils import run_tests, TestCase - - -class TestCollectives(TestCase): - def test_barrier(self) -> None: - store = dist.HashStore() - - world_size = 2 - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - collectives.barrier("foo", timedelta(seconds=10), True) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_broadcast(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - if rank == 2: - collectives.broadcast_send("foo", b"data", timeout) - else: - out = collectives.broadcast_recv("foo", timeout) - self.assertEqual(out, b"data") - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_gather(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - if rank == 2: - out = collectives.gather_recv("foo", str(rank), timeout) - self.assertEqual(out, [b"0", b"1", b"2", b"3"]) - else: - collectives.gather_send("foo", str(rank), timeout) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_scatter(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - if rank == 2: - out = collectives.scatter_send( - "foo", [str(i) for i in range(world_size)], timeout - ) - else: - out = collectives.scatter_recv("foo", timeout) - self.assertEqual(out, str(rank).encode()) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_all_sum(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(seconds=10) - - def f(rank: int) -> None: - collectives = dist._StoreCollectives(store, rank, world_size) - out = collectives.all_sum("foo", rank, timeout) - self.assertEqual(out, sum(range(world_size))) - - with ThreadPool(world_size) as pool: - pool.map(f, range(world_size)) - - def test_broadcast_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex(Exception, "Wait timeout"): - collectives.broadcast_recv("foo", timeout) - - def test_gather_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "gather failed -- missing ranks: 0, 2, 3" - ): - collectives.gather_recv("foo", "data", timeout) - - def 
test_scatter_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex(Exception, "Wait timeout"): - collectives.scatter_recv("foo", timeout) - - def test_all_gather_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "all_gather failed -- missing ranks: 0, 2, 3" - ): - collectives.all_gather("foo", "data", timeout) - - def test_barrier_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "barrier failed -- missing ranks: 0, 2, 3" - ): - collectives.barrier("foo", timeout, True) - - def test_all_sum_timeout(self) -> None: - store = dist.HashStore() - - world_size = 4 - timeout = timedelta(milliseconds=1) - collectives = dist._StoreCollectives(store, 1, world_size) - with self.assertRaisesRegex( - Exception, "barrier failed -- missing ranks: 0, 2, 3" - ): - collectives.all_sum("foo", 1, timeout) - - def test_unique(self) -> None: - store = dist.HashStore() - - collectives = dist._StoreCollectives(store, 1, 1) - collectives.broadcast_send("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.broadcast_send("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.broadcast_recv("foo") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.gather_send("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.gather_recv("foo", "asdf") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.scatter_send("foo", ["asdf"]) - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.scatter_recv("foo") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.all_gather("foo", "bar") - - with self.assertRaisesRegex(Exception, "Key foo has already been used"): - collectives.all_sum("foo", 2) - - -if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" - - run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 74a73a3ddaa46..28d790e3d6903 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -210,20 +210,6 @@ class PrefixStore(Store): @property def underlying_store(self) -> Store: ... -class _ControlCollectives: - def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ... - def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ... - def broadcast_recv(self, key: str, timeout: timedelta) -> str: ... - def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ... - def gather_recv(self, key: str, timeout: timedelta) -> str: ... - def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... - def scatter_recv(self, key: str, timeout: timedelta) -> str: ... - def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... - def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ... 
- -class _StoreCollectives(_ControlCollectives): - def __init__(self, store: Store, rank: int, world_size: int) -> None: ... - class _DistributedBackendOptions: def __init__(self): ... @property diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp index 3697d62301ba3..1453c0a72808a 100644 --- a/torch/csrc/distributed/c10d/HashStore.hpp +++ b/torch/csrc/distributed/c10d/HashStore.hpp @@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store { std::vector get(const std::string& key) override; void wait(const std::vector& keys) override { - wait(keys, timeout_); + wait(keys, Store::kDefaultTimeout); } void wait( diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index 993284fa7cc56..af715ba98a794 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -97,33 +97,4 @@ class TORCH_API Store : public torch::CustomClassHolder { std::chrono::milliseconds timeout_; }; -/* -StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it -when it returns. -*/ -class StoreTimeoutGuard { - public: - explicit StoreTimeoutGuard( - Store& store, - const std::chrono::milliseconds& timeout) - : store_(store) { - oldTimeout_ = store.getTimeout(); - store.setTimeout(timeout); - } - - ~StoreTimeoutGuard() { - store_.setTimeout(oldTimeout_); - } - - /* Disabling copy and move semantics */ - StoreTimeoutGuard(const StoreTimeoutGuard&) = delete; - StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete; - StoreTimeoutGuard(StoreTimeoutGuard&&) = delete; - StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete; - - private: - Store& store_; - std::chrono::milliseconds oldTimeout_; -}; - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp deleted file mode 100644 index b98f9a71fb024..0000000000000 --- a/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include - -namespace c10d { - -using namespace std::chrono_literals; - -class TORCH_API ControlCollectives : public torch::CustomClassHolder { - public: - virtual void barrier( - const std::string& key, - std::chrono::milliseconds timeout = 5min, - bool block = true) = 0; - - virtual void broadcastSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - virtual std::vector broadcastRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual void gatherSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - virtual std::vector> gatherRecv( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual std::vector scatterSend( - const std::string& key, - const std::vector>& data, - std::chrono::milliseconds timeout = 5min) = 0; - virtual std::vector scatterRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual std::vector> allGather( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) = 0; - - virtual int64_t allSum( - const std::string& key, - int64_t data, - std::chrono::milliseconds timeout = 5min) = 0; -}; - -} // namespace c10d diff --git 
a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp deleted file mode 100644 index 995899441d461..0000000000000 --- a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp +++ /dev/null @@ -1,222 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace { -std::string getRankKey(const std::string& key, int rank) { - return fmt::format("{}/{}", key, rank); -} -} // namespace - -namespace c10d { - -StoreCollectives::StoreCollectives( - c10::intrusive_ptr<::c10d::Store> store, - int rank, - int worldSize) - : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {} - -void StoreCollectives::barrier( - const std::string& key, - std::chrono::milliseconds timeout, - bool blocking) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto num_members_key = fmt::format("{}/num_members", key); - auto last_members_key = fmt::format("{}/last_members", key); - - auto idx = store_->add(num_members_key, 1); - store_->set(getRankKey(key, rank_), "joined"); - - if (idx == worldSize_) { - store_->set(last_members_key, ""); - } else if (blocking) { - try { - store_->wait({last_members_key}); - } catch (const std::exception& e) { - std::string msg = "barrier failed -- missing ranks: "; - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - if (!store_->check({rank_key})) { - msg += fmt::format("{}, ", i); - } - } - throw std::runtime_error(msg + e.what()); - } - } -} - -void StoreCollectives::broadcastSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - store_->set(key, data); -} - -std::vector StoreCollectives::broadcastRecv( - const std::string& key, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - return store_->get(key); -} - -void StoreCollectives::gatherSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto rank_key = getRankKey(key, rank_); - store_->set(rank_key, data); -} - -std::vector> StoreCollectives::gatherRecv( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - std::vector keys; - keys.reserve(worldSize_); - - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - keys.emplace_back(rank_key); - } - - std::vector> results; - results.reserve(worldSize_); - - try { - results = store_->multiGet(keys); - } catch (const std::exception& e) { - std::string msg = "gather failed -- missing ranks: "; - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - if (!store_->check({rank_key})) { - msg += fmt::format("{}, ", i); - } - } - throw std::runtime_error(msg + e.what()); - } - - // insert local data - results.insert(results.begin() + rank_, data); - return results; -} - -std::vector StoreCollectives::scatterSend( - const std::string& key, - const std::vector>& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - std::vector keys; - keys.reserve(worldSize_); - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - 
continue; - } - auto rank_key = getRankKey(key, i); - keys.emplace_back(rank_key); - } - auto local = data.at(rank_); - - std::vector> toSend{data}; - - toSend.erase(toSend.begin() + rank_); - - store_->multiSet(keys, toSend); - - return local; -} - -std::vector StoreCollectives::scatterRecv( - const std::string& key, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto rank_key = getRankKey(key, rank_); - return store_->get(rank_key); -} - -std::vector> StoreCollectives::allGather( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - auto localKey = getRankKey(key, rank_); - store_->set(localKey, data); - - std::vector keys; - keys.reserve(worldSize_); - - for (int i = 0; i < worldSize_; i++) { - auto rank_key = getRankKey(key, i); - keys.emplace_back(rank_key); - } - - try { - return store_->multiGet(keys); - } catch (const std::exception& e) { - std::string msg = "all_gather failed -- missing ranks: "; - for (int i = 0; i < worldSize_; i++) { - if (i == rank_) { - continue; - } - auto rank_key = getRankKey(key, i); - if (!store_->check({rank_key})) { - msg += fmt::format("{}, ", i); - } - } - throw std::runtime_error(msg + e.what()); - } -} - -int64_t StoreCollectives::allSum( - const std::string& key, - int64_t value, - std::chrono::milliseconds timeout) { - enforceUnique(key); - StoreTimeoutGuard g{*store_, timeout}; - - store_->add(key, value); - - barrier(key + "/barrier", timeout); - - return store_->add(key, 0); -} - -void StoreCollectives::enforceUnique(const std::string& key) { - auto it = seenKeys_.find(key); - TORCH_INTERNAL_ASSERT( - it == seenKeys_.end(), "Key ", key, " has already been used."); - seenKeys_.emplace(key); -} - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp b/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp deleted file mode 100644 index 7d3eb5038565e..0000000000000 --- a/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp +++ /dev/null @@ -1,68 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace c10d { - -class TORCH_API StoreCollectives : public ControlCollectives { - public: - explicit StoreCollectives( - c10::intrusive_ptr store, - int rank, - int worldSize); - - void barrier( - const std::string& key, - std::chrono::milliseconds timeout = 5min, - bool block = true) override; - - void broadcastSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - std::vector broadcastRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) override; - - void gatherSend( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - std::vector> gatherRecv( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - - std::vector scatterSend( - const std::string& key, - const std::vector>& data, - std::chrono::milliseconds timeout = 5min) override; - std::vector scatterRecv( - const std::string& key, - std::chrono::milliseconds timeout = 5min) override; - - std::vector> allGather( - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) override; - - int64_t allSum( - const std::string& key, - int64_t data, - std::chrono::milliseconds timeout = 5min) override; - - private: - void 
enforceUnique(const std::string& key); - - private: - c10::intrusive_ptr store_; - int rank_; - int worldSize_; - - c10::FastSet seenKeys_{}; -}; - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 505b64e2a6976..483becbce0094 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -6,9 +6,6 @@ #include #include #include -#include -#include -#include #ifndef _WIN32 #include #include @@ -139,34 +136,6 @@ namespace torch::distributed::c10d { namespace { -py::bytes toPyBytes(const std::vector& data) { - return py::bytes(reinterpret_cast(data.data()), data.size()); -} - -std::vector toPyBytes( - const std::vector>& data) { - std::vector out; - out.reserve(data.size()); - for (const std::vector& data_ : data) { - out.emplace_back(reinterpret_cast(data_.data()), data_.size()); - } - return out; -} - -std::vector toVec8(const std::string& data) { - std::vector out{data.begin(), data.end()}; - return out; -} - -std::vector> toVec8(const std::vector& data) { - std::vector> out; - out.reserve(data.size()); - for (auto& data_ : data) { - out.emplace_back(toVec8(data_)); - } - return out; -} - template using shared_ptr_class_ = py::class_>; @@ -197,7 +166,8 @@ class PythonStore : public ::c10d::Store { pybind11::get_overload(static_cast(this), "set"); TORCH_INTERNAL_ASSERT(fn, "Not implemented."); // Call function with a py::bytes object for the value. - fn(key, toPyBytes(value)); + fn(key, + py::bytes(reinterpret_cast(value.data()), value.size())); } // Note: this function manually calls the Python-side overload @@ -214,7 +184,7 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast(fn(key)); - return toVec8(str); + return std::vector(str.begin(), str.end()); } // Note: this function manually calls the Python-side overload @@ -234,8 +204,14 @@ class PythonStore : public ::c10d::Store { // std::vector. There is no API for directly accessing // the contents of the py::bytes object. std::string str = pybind11::cast( - fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue))); - return toVec8(str); + fn(key, + py::bytes( + reinterpret_cast(expectedValue.data()), + expectedValue.size()), + py::bytes( + reinterpret_cast(desiredValue.data()), + desiredValue.size()))); + return std::vector(str.begin(), str.end()); } int64_t add(const std::string& key, int64_t value) override { @@ -277,7 +253,8 @@ class PythonStore : public ::c10d::Store { return Store::append(key, value); } // Call function with a py::bytes object for the value. - fn(key, toPyBytes(value)); + fn(key, + py::bytes(reinterpret_cast(value.data()), value.size())); } std::vector> multiGet( @@ -310,7 +287,14 @@ class PythonStore : public ::c10d::Store { return Store::multiSet(keys, values); } - fn(keys, toPyBytes(values)); + std::vector bytes; + bytes.reserve(values.size()); + for (auto& value : values) { + bytes.emplace_back( + reinterpret_cast(value.data()), value.size()); + } + + fn(keys, bytes); } bool hasExtendedApi() const override { @@ -989,7 +973,10 @@ and :class:`~torch.distributed.HashStore`). 
"set", [](::c10d::Store& store, const std::string& key, - const std::string& value) { store.set(key, toVec8(value)); }, + const std::string& value) { + std::vector value_(value.begin(), value.end()); + store.set(key, value_); + }, py::call_guard(), R"( Inserts the key-value pair into the store based on the supplied ``key`` and @@ -1014,9 +1001,14 @@ Example:: const std::string& key, const std::string& expected_value, const std::string& desired_value) -> py::bytes { - auto value = store.compareSet( - key, toVec8(expected_value), toVec8(desired_value)); - return toPyBytes(value); + std::vector expectedValue_( + expected_value.begin(), expected_value.end()); + std::vector desiredValue_( + desired_value.begin(), desired_value.end()); + auto value = + store.compareSet(key, expectedValue_, desiredValue_); + return py::bytes( + reinterpret_cast(value.data()), value.size()); }, py::call_guard(), R"( @@ -1048,7 +1040,8 @@ Example:: py::gil_scoped_release guard; return store.get(key); }(); - return toPyBytes(value); + return py::bytes( + reinterpret_cast(value.data()), value.size()); }, R"( Retrieves the value associated with the given ``key`` in the store. If ``key`` is not @@ -1247,7 +1240,8 @@ Example:: [](::c10d::Store& store, const std::string& key, const std::string& value) { - store.append(key, toVec8(value)); + std::vector value_(value.begin(), value.end()); + store.append(key, value_); }, py::call_guard(), R"( @@ -1274,7 +1268,14 @@ Example:: py::gil_scoped_release guard; return store.multiGet(keys); }(); - return toPyBytes(values); + std::vector res; + for (auto& value : values) { + auto bytes = py::bytes( + reinterpret_cast(value.data()), + value.size()); + res.push_back(bytes); + } + return res; }, R"( Retrieve all values in ``keys``. If any key in ``keys`` is not @@ -1297,7 +1298,12 @@ Example:: [](::c10d::Store& store, const std::vector& keys, const std::vector& values) { - store.multiSet(keys, toVec8(values)); + std::vector> vals; + vals.reserve(values.size()); + for (auto& value : values) { + vals.emplace_back(value.begin(), value.end()); + } + store.multiSet(keys, vals); }, py::call_guard(), R"( @@ -1481,212 +1487,6 @@ that adds a prefix to each key inserted to the store. &::c10d::PrefixStore::getUnderlyingNonPrefixStore, R"(Recursively to get the store before layers of wrapping with PrefixStore.)"); - using namespace std::chrono_literals; - - auto collectives = - py::class_< - ::c10d::ControlCollectives, - c10::intrusive_ptr<::c10d::ControlCollectives>>( - module, - "_ControlCollectives", - R"( -Base class for all ControlCollectives implementations. -)") - .def( - "barrier", - &::c10d::ControlCollectives::barrier, - py::arg("key"), - py::arg("timeout") = 5min, - py::arg("block") = true, - py::call_guard(), - R"( -Blocks until all workers have entered this function. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. - block (bool): whether to block this working waiting on the results of the barrier. -)") - .def( - "all_sum", - &::c10d::ControlCollectives::allSum, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - py::call_guard(), - R"( -Computes a sum across all workers and returns the final value. - -Arguments: - key (str): The unique key used to identify this operation. - data (int): The data to sum. - timeout (duration): The timeout for this operation. 
-)") - .def( - "broadcast_send", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - collectives.broadcastSend(key, toVec8(data), timeout); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - py::call_guard(), - R"( -Sends data to all other workers. Must be only called from one worker. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)") - .def( - "broadcast_recv", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.broadcastRecv(key, timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("timeout") = 5min, - R"( -Receives data broadcasted from 1 worker. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. -)") - .def( - "gather_send", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - collectives.gatherSend(key, toVec8(data), timeout); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - py::call_guard(), - R"( -Sends data to one other worker. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)") - .def( - "gather_recv", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.gatherRecv(key, toVec8(data), timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - R"( -Receives data broadcasted from all workers. Must only be called by one worker. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. -)") - - .def( - "scatter_send", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::vector& data, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.scatterSend(key, toVec8(data), timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - R"( -Sends rank specific data to all other workers. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)") - .def( - "scatter_recv", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.scatterRecv(key, timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("timeout") = 5min, - R"( -Receives rank specific data from one worker. - -Arguments: - key (str): The unique key used to identify this operation. - timeout (duration): The timeout for this operation. 
-)") - - .def( - "all_gather", - [](::c10d::ControlCollectives& collectives, - const std::string& key, - const std::string& data, - std::chrono::milliseconds timeout = 5min) { - auto out = [&]() { - py::gil_scoped_release guard; - return collectives.allGather(key, toVec8(data), timeout); - }(); - return toPyBytes(out); - }, - py::arg("key"), - py::arg("data"), - py::arg("timeout") = 5min, - R"( -Sends data to all workers and receives data from all other workers. - -Arguments: - key (str): The unique key used to identify this operation. - data (str): The data to send. - timeout (duration): The timeout for this operation. -)"); - - intrusive_ptr_class_<::c10d::StoreCollectives>( - module, - "_StoreCollectives", - collectives, - R"( -An implementation of ControlCollectives that uses the provided store as the underlying -communication mechanism. - )") - .def( - py::init, int, int>(), - py::arg("store"), - py::arg("rank"), - py::arg("world_size")); - auto processGroup = py::class_< ::c10d::ProcessGroup, diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 3e7dce97b54c9..eb7a690fa9589 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -54,8 +54,6 @@ def is_available() -> bool: set_debug_level, set_debug_level_from_env, _make_nccl_premul_sum, - _ControlCollectives, - _StoreCollectives, ) class _DistributedPdb(pdb.Pdb): From 7100a729502f463a14e195e7b3ec04d7103fce99 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sun, 19 May 2024 03:48:00 +0100 Subject: [PATCH 02/35] [inductor] Fix ops.scan for non-commutative operators (#126633) `tl.associative_scan` supports non-commutative combine functions but `tl.reduce` doesn't. This effects non-persistent scans, where we use the reduction from the previous loop iterations as the base for future iterations. Here I work around this by taking the last element of the scan output and using that as the reduced value. This is done using a trick where we create a mask that is 1 at the desired element and 0 elsewhere, then sum over it. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126633 Approved by: https://github.com/Chillee, https://github.com/lezcano --- test/inductor/test_cuda_repro.py | 26 ++++++++++++++++++++++++++ torch/_inductor/codegen/triton.py | 24 ++++++++++++++++-------- torch/_inductor/lowering.py | 2 +- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index f303330bc1140..da3869c5a3acc 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1135,6 +1135,32 @@ def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3): fn(*args) torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address + def test_non_commutative_scan_op(self): + from torch._higher_order_ops.associative_scan import associative_scan + + a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda") + b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda") + + def baseline(v, u): + A = [] + A.append(b[:, 0]) + for i in range(1, v.shape[1]): + A.append(a[:, i] * A[i - 1] + b[:, i]) + return torch.stack(A, dim=1) + + def combine_fn(i, j): + ia, ib = i + ja, jb = j + return ia * ja, ib * ja + jb + + @torch.compile + def compiled_scan(a, b): + return associative_scan(combine_fn, (a, b), dim=-1)[1] + + out1 = baseline(a, b) + out2 = compiled_scan(a, b) + self.assertEqual(out1, out2) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 183d28605b87a..eddb7bdcdcd14 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -13,7 +13,6 @@ import torch import torch._logging -import torch.utils._pytree as pytree from torch._dynamo.utils import preserve_rng_state from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties @@ -1650,13 +1649,22 @@ def cse_multiple(line, n, masks): ) if not self.persistent_reduction: - partial_reduce_vars = pytree.tree_map( - self.reduction_resize, - cse_multiple( - f"tl.reduce(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", - len(values), - None, - ), + + def sum_fn(a, b): + return [ops.add(ai, bi) for ai, bi in zip(a, b)] + + sum_helper_fn = self._lift_helper(sum_fn, len(values)) + pre_reduce_vars = ", ".join( + f"{scan_var} * (rbase == (RBLOCK - 1))" + for scan_var in partial_scan_vars + ) + # tl.reduce doesn't work for non-commutative operators, so instead + # of repeating the scan op as a reduction, we use sum to select the + # last scan value + partial_reduce_vars = cse_multiple( + f"tl.reduce(({pre_reduce_vars}), -1, {sum_helper_fn}, keep_dims=True)", + len(values), + masks, ) accs_next = combine_fn(tuple(accumulators), partial_reduce_vars) full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 07899fe2ccd09..adf8c542d33e6 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -5919,7 +5919,7 @@ def wrapped_combine_fn(lhs, rhs): kwargs["dtypes"] = tuple(x.get_dtype() for x in input) kwargs["inner_fns"] = tuple(x.make_loader() for x in input) result = ir.Scan.create(**kwargs, combine_fn=wrapped_combine_fn) - if result is None: + if result[0] is None: raise RuntimeError("Unable to generate code for associative_scan op") return result From cb69c51b6f1d1803c8c2a01d26e3b3474917cf39 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 20 May 2024 12:14:22 +0000 Subject: 
[PATCH 03/35] Revert " Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127)" This reverts commit cf35a591b95220aa1bfcc04ff8a943efd1d6d6eb. Reverted https://github.com/pytorch/pytorch/pull/125127 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/125127#issuecomment-2120337584)) --- test/inductor/test_compiled_optimizers.py | 2 - test/test_cuda.py | 393 +++++++++++-------- torch/testing/_internal/common_optimizers.py | 64 +-- 3 files changed, 234 insertions(+), 225 deletions(-) diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 7100837e9b92f..b2d0ed91809f9 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -136,8 +136,6 @@ class KernelCounts(NamedTuple): "test_sgd_momentum_foreach_cuda": 5, "test_sgd_weight_decay_maximize_cuda": 4, "test_sgd_weight_decay_maximize_cpu": 4, - "test_sgd_weight_decay_cpu": 4, - "test_sgd_weight_decay_cuda": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, "test_sgd_cuda": 4, diff --git a/test/test_cuda.py b/test/test_cuda.py index cc3e2380f2664..93e08eff4df6d 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -37,11 +37,7 @@ instantiate_device_type_tests, onlyCUDA, ) -from torch.testing._internal.common_optimizers import ( - _get_optim_inputs_including_global_cliquey_kwargs, - optim_db, - optims, -) +from torch.testing._internal.common_optimizers import optim_db, optims from torch.testing._internal.common_utils import ( freeze_rng_state, gcIfJetson, @@ -3204,6 +3200,111 @@ def _test_graphed_optimizer( for p_control, p_graphed in zip(params_control, params_graphed): self.assertEqual(p_control, p_graphed) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_optims(self): + # Needs generalization if we want to extend this test to non-Adam-like optimizers. 
+ cases = ( + [ + ( + optimizer_ctor, + { + "lr": 0.1, + "betas": (0.8, 0.7), + "foreach": foreach, + "decoupled_weight_decay": decoupled_weight_decay, + "weight_decay": weight_decay, + }, + ) + for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product( + ( + torch.optim.NAdam, + torch.optim.RAdam, + ), + ( + False, + True, + ), + ( + False, + True, + ), + ( + 0.0, + 0.1, + ), + ) + ] + + [ + ( + torch.optim.Rprop, + {"lr": 0.1, "foreach": foreach, "maximize": maximize}, + ) + for foreach, maximize in product( + ( + False, + True, + ), + ( + False, + True, + ), + ) + ] + + [ + ( + optimizer_ctor, + { + "lr": 0.1, + "betas": (0.8, 0.7), + "foreach": foreach, + "amsgrad": amsgrad, + }, + ) + for optimizer_ctor, foreach, amsgrad in product( + (torch.optim.Adam, torch.optim.AdamW), + (False, True), + (False, True), + ) + ] + + [ + ( + optimizer_ctor, + {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, + ) + for optimizer_ctor, amsgrad in product( + (torch.optim.Adam, torch.optim.AdamW), (False, True) + ) + ] + + [ + ( + optimizer_ctor, + { + "lr": 0.1, + "foreach": foreach, + "maximize": maximize, + "weight_decay": weight_decay, + }, + ) + for optimizer_ctor, foreach, maximize, weight_decay in product( + ( + torch.optim.Adamax, + torch.optim.ASGD, + torch.optim.Adadelta, + torch.optim.RMSprop, + ), + (False, True), + (False, True), + (0, 0.1), + ) + ] + ) + + for optimizer_ctor, kwargs in cases: + with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): + self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3275,6 +3376,123 @@ def test_graph_optims_with_explicitly_capturable_param_groups(self): self.assertEqual(ref_p1, param1) self.assertEqual(ref_p2, param2) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_scaling_fused_optimizers(self): + cases = [ + ( + optimizer_ctor, + {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, + ) + for optimizer_ctor, amsgrad in product( + (torch.optim.Adam, torch.optim.AdamW), (False, True) + ) + ] + list( + product( + (torch.optim.SGD,), + [ + { + "lr": 0.1, + "momentum": 0.0, + "dampening": d, + "weight_decay": w, + "nesterov": n, + "fused": True, + } + for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) + ] + + [ + { + "lr": 0.1, + "momentum": 0.5, + "dampening": d, + "weight_decay": w, + "nesterov": n, + "fused": True, + } + for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) + ], + ) + ) + + steps_warmup = 3 + steps_train = 2 + + for OptClass, kwargs in cases: + has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) + for actually_do_graphs in (True, False) if has_capturable_arg else (True,): + params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. 
+ grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + with torch.no_grad(): + grads_control = [[g.clone() for g in gs] for gs in grads] + grads_graphed = [[g.clone() for g in gs] for gs in grads] + + # Gradient Scaler + scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) + with torch.no_grad(): + scaler_for_control._lazy_init_scale_growth_tracker( + torch.device("cuda") + ) + + scaler_for_graphed = torch.cuda.amp.GradScaler() + scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) + with torch.no_grad(): + scaler_for_graphed._lazy_init_scale_growth_tracker( + torch.device("cuda") + ) + + # Control (capturable=False) + if has_capturable_arg: + kwargs["capturable"] = False + opt = OptClass(params_control, **kwargs) + + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads_control[i][j] + scaler_for_control.step(opt) + scaler_for_control.update() + + # capturable=True + if has_capturable_arg: + kwargs["capturable"] = True + opt = OptClass(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads_graphed[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i + steps_warmup][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -4480,175 +4698,10 @@ def test_no_triton_on_import(self): self.assertEqual(rc, "False", "Triton was imported when importing torch!") -@torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. 
- @onlyCUDA - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" - ) - @optims( - [optim for optim in optim_db if optim.has_capturable_arg], - dtypes=[torch.float32], - ) - def test_graph_optims(self, device, dtype, optim_info): - optim_cls = optim_info.optim_cls - all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( - device, dtype, optim_info, skip=("differentiable",) - ) - - steps_warmup = 3 - steps_train = 2 - - for optim_input in all_optim_inputs: - kwargs = optim_input.kwargs - - # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam - # and torch.optim.adamw - kwargs["lr"] = 0.1 - - for actually_do_graphs in (True, False): - params = [ - torch.randn((i + 5, i + 5), device=device) for i in range(2) - ] + [torch.randn((), device=device)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - - # Control (capturable=False) - kwargs["capturable"] = False - - opt = optim_cls(params_control, **kwargs) - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads[i][j] - opt.step() - - # capturable=True - kwargs["capturable"] = True - opt = optim_cls(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads[i][j] - opt.step() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - opt.step() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads[i + steps_warmup][j] - opt.step() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - - @onlyCUDA - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @optims( - [optim for optim in optim_db if "fused" in optim.supported_impls], - dtypes=[torch.float32], - ) - def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): - optim_cls = optim_info.optim_cls - - steps_warmup = 3 - steps_train = 2 - - optim_inputs = optim_info.optim_inputs_func(device=device) - - for optim_input in optim_inputs: - kwargs = optim_input.kwargs - kwargs["fused"] = True - - for actually_do_graphs in ( - (True, False) if optim_info.has_capturable_arg else (True,) - ): - params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] - - # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. 
- grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - with torch.no_grad(): - grads_control = [[g.clone() for g in gs] for gs in grads] - grads_graphed = [[g.clone() for g in gs] for gs in grads] - - # Gradient Scaler - scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) - with torch.no_grad(): - scaler_for_control._lazy_init_scale_growth_tracker(device) - - scaler_for_graphed = torch.cuda.amp.GradScaler() - scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) - with torch.no_grad(): - scaler_for_graphed._lazy_init_scale_growth_tracker(device) - - # Control (capturable=False) - if optim_info.has_capturable_arg: - kwargs["capturable"] = False - opt = optim_cls(params_control, **kwargs) - - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads_control[i][j] - scaler_for_control.step(opt) - scaler_for_control.update() - - # capturable=True - if optim_info.has_capturable_arg: - kwargs["capturable"] = True - opt = optim_cls(params_graphed, **kwargs) - - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads_graphed[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i + steps_warmup][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) - @onlyCUDA @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 5abacf2df1d61..c81efb093cd8b 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -123,8 +123,6 @@ def __init__( supported_impls: Tuple[str] = ("foreach", "differentiable"), # the optim supports passing in sparse gradients as well as dense grads supports_sparse: bool = False, - # the optimizer constructor supports passing in capturable as a kwarg - has_capturable_arg: bool = False, # the optim only supports one config: sparse grads w/ dense params, see SparseAdam only_supports_sparse_grads: bool = False, # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, @@ -149,7 +147,6 @@ def __init__( self.scheduler_inputs = scheduler_inputs self.supported_impls = supported_impls self.supports_sparse = supports_sparse - self.has_capturable_arg = has_capturable_arg self.metadata_for_sparse = metadata_for_sparse self.only_supports_sparse_grads = only_supports_sparse_grads self.supports_complex = supports_complex @@ -314,11 +311,10 @@ def optim_inputs_func_adadelta(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), - OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize, weight_decay", + 
desc="maximize", ), OptimizerInput( params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" @@ -530,15 +526,10 @@ def optim_inputs_func_adamax(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), - OptimizerInput( - params=None, - kwargs={"maximize": True}, - desc="maximize", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize, weight_decay", + desc="maximize", ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -690,22 +681,16 @@ def optim_inputs_func_nadam(device, dtype=None): kwargs={"momentum_decay": 6e-3}, desc="non-zero momentum_decay", ), - OptimizerInput( - params=None, - kwargs={ - "weight_decay": 0.1, - }, - desc="weight_decay", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, - desc="weight_decay, momentum_decay", + desc="weight_decay", ), OptimizerInput( params=None, kwargs={ "weight_decay": 0.1, + "momentum_decay": 6e-3, "decoupled_weight_decay": True, }, desc="decoupled_weight_decay", @@ -833,26 +818,11 @@ def optim_inputs_func_rmsprop(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), - OptimizerInput( - params=None, - kwargs={ - "maximize": True, - }, - desc="maximize", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True}, desc="centered", ), - OptimizerInput( - params=None, - kwargs={ - "maximize": True, - "weight_decay": 0.1, - }, - desc="maximize, weight_decay", - ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, @@ -866,7 +836,7 @@ def optim_inputs_func_rmsprop(device, dtype=None): "momentum": 0.1, "maximize": True, }, - desc="maximize, centered, weight_decay, w/ momentum", + desc="maximize", ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -937,15 +907,7 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" ), - OptimizerInput( - params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" - ), OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), - OptimizerInput( - params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", - ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "dampening": 0.5}, @@ -954,13 +916,18 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"momentum": 0.9, "weight_decay": 0.1}, - desc="weight_decay w/ momentum", + desc="non-zero weight_decay", ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, desc="nesterov", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), ] @@ -1130,7 +1097,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adadelta, optim_error_inputs_func=optim_error_inputs_func_adadelta, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1266,7 +1232,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), - has_capturable_arg=True, decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1333,7 +1298,6 
@@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamax, optim_error_inputs_func=optim_error_inputs_func_adamax, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1384,7 +1348,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), supports_fused_on=("cpu", "cuda"), - has_capturable_arg=True, decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1451,7 +1414,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_asgd, optim_error_inputs_func=optim_error_inputs_func_asgd, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1544,7 +1506,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1600,7 +1561,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_radam, optim_error_inputs_func=optim_error_inputs_func_radam, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1646,7 +1606,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rmsprop, optim_error_inputs_func=optim_error_inputs_func_rmsprop, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1696,7 +1655,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rprop, optim_error_inputs_func=optim_error_inputs_func_rprop, supported_impls=("foreach", "differentiable"), - has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # Rprop doesn't update for non-contiguous, see #118117 From 2f53747ec6157f8a4169cd01355ea91979b3dc40 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 20 May 2024 12:39:51 +0000 Subject: [PATCH 04/35] Speedup bf16 gemm fallback on ARM (#126592) By dispatching it to multiple threads and using vectorized dot operation (with fp16 to fp32 upcasts via left shift) This bumps stories110M eval from 22 to 55 tokens/sec using bfloat16 TODO: - Refactor tinygemm template and use it here Pull Request resolved: https://github.com/pytorch/pytorch/pull/126592 Approved by: https://github.com/mikekgfb --- aten/src/ATen/native/cpu/BlasKernel.cpp | 47 ++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 1cc53da3584ea..587809ea57c8d 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -308,7 +308,18 @@ void gemm_notrans_( } -static float compute_dot(const float16_t *a, const float16_t *b, int64_t l) { +inline float32x4_t load_as_float32x4(const Half* ptr) { + return vcvt_f32_f16(vld1_f16(reinterpret_cast(ptr))); +} + +inline float32x4_t load_as_float32x4(const BFloat16* ptr) { + 
int32x4_t shift = vdupq_n_s32(16); + uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast(ptr))); + return vreinterpretq_f32_u32(vshlq_u32(as_int, shift)); +} + +template +static float compute_dot(const T* a, const T* b, int64_t l) { if ((l&3) != 0) { return sum(l, [&](int64_t i) -> float { return float(a[i]) * float(b[i]); @@ -316,8 +327,8 @@ static float compute_dot(const float16_t *a, const float16_t *b, int64_t l) { } float32x4_t rcv = vdupq_n_f32(0); for (int64_t idx = 0; idx < l; idx += 4) { - float32x4_t aVec = vcvt_f32_f16(vld1_f16(a + idx)); - float32x4_t bVec = vcvt_f32_f16(vld1_f16(b + idx)); + float32x4_t aVec = load_as_float32x4(a + idx); + float32x4_t bVec = load_as_float32x4(b + idx); rcv = vaddq_f32(rcv, vmulq_f32(aVec, bVec)); } auto sum = vpaddq_f32(rcv, rcv); @@ -343,7 +354,35 @@ void gemm_transa_( for (const auto i : c10::irange(begin, end)) { const auto *b_ = b; for (const auto j : c10::irange(n)) { - const auto dot = compute_dot(reinterpret_cast(a_), reinterpret_cast(b_), k); + const auto dot = compute_dot(a_, b_, k); + b_ += ldb; + if (beta == 0) { + c[j*ldc+i] = alpha*dot; + } else { + c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot; + } + } + a_ += lda; + } + }); +} + +template <> +void gemm_transa_( + TransposeType transa, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + float beta, + at::BFloat16 *c, int64_t ldc) { + // c = alpha * (a.T @ b) + beta * c + parallel_for(0, m, 1, [&](int64_t begin, int64_t end) { + const auto *a_ = a + begin * lda; + for (const auto i : c10::irange(begin, end)) { + const auto *b_ = b; + for (const auto j : c10::irange(n)) { + const auto dot = compute_dot(a_, b_, k); b_ += ldb; if (beta == 0) { c[j*ldc+i] = alpha*dot; From 3642e51ea527e23ded10afc266f298b0cb5350c8 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 20 May 2024 14:04:02 +0800 Subject: [PATCH 05/35] [Quant][PT2E] enable qlinear post op fusion for dynamic quant & qat (#122667) **Description** Add fusion path for dynamic quant and for QAT. 
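As a rough sketch (illustrative only, not code taken from this PR), the eager-mode module shape that the add/relu fusion below targets looks like:

```python
import torch

class LinearAddRelu(torch.nn.Module):
    """Toy module whose graph contains linear -> add -> relu."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        self.relu = torch.nn.ReLU()

    def forward(self, x, other):
        # after PT2E quantization, this roughly corresponds to the
        # qx -> qlinear -> add -> relu pattern described below
        return self.relu(self.linear(x) + other)
```
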
The following patterns can be matched for static quant with QAT cases: `qx -> qlinear -> add -> optional relu -> optional type convert -> optional quant` The following patterns can be matched for dynamic quant cases: `qx -> qlinear -> add -> optional relu` **Test plan** python test/inductor/test_mkldnn_pattern_matcher.py -k test_qlinear python test/inductor/test_cpu_cpp_wrapper.py -k test_qlinear python test/test_quantization.py -k test_linear_unary python test/test_quantization.py -k test_linear_binary Pull Request resolved: https://github.com/pytorch/pytorch/pull/122667 Approved by: https://github.com/jgong5 --- test/inductor/test_mkldnn_pattern_matcher.py | 8 +- .../pt2e/test_x86inductor_quantizer.py | 395 ++++++++++++------ .../quantizer/x86_inductor_quantizer.py | 6 +- 3 files changed, 274 insertions(+), 135 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 756de35df84cf..9c39f1c140018 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1688,10 +1688,13 @@ def matcher_check_fn(): to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2) self.assertEqual( counters["inductor"]["qlinear_binary_matcher_nodes"], - 5 + 2 * use_relu + to_bf16_after_binary, + (4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary, ) - for is_qat in [False, True]: + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(is_qat_list, is_dynamic_list) + for is_qat, is_dynamic in cases: self._test_common( mod, (v,), @@ -1699,6 +1702,7 @@ def matcher_check_fn(): check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, matcher_check_fn=matcher_check_fn, is_qat=is_qat, + is_dynamic=is_dynamic, ) @skipIfNoDynamoSupport diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 218b30bd9e33f..3202900a28624 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1198,84 +1198,63 @@ def test_linear(self): node_list, ) - @skipIfNoX86 - def test_linear_unary(self): + def _test_linear_unary_helper( + self, + post_op_module, + post_op_aten, + post_op_aten_inplace, + post_op_algo_list=None, + is_qat=False, + is_dynamic=False, + ): """ Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. 
""" use_bias_list = [True, False] inplace_list = [True, False] - postop_list = [nn.ReLU, nn.LeakyReLU] # only test two to save time - cases = itertools.product(use_bias_list, inplace_list, postop_list) - post_op_map = { - nn.ReLU: [torch.ops.aten.relu_.default, torch.ops.aten.relu.default], - nn.LeakyReLU: [ - torch.ops.aten.leaky_relu_.default, - torch.ops.aten.leaky_relu.default, - ], - } + if post_op_algo_list is None: + post_op_algo_list = [None] + cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list) with override_quantized_engine("x86"), torch.no_grad(): - for use_bias, inplace, postop in cases: + for use_bias, inplace, post_op_algo in cases: + if inplace and post_op_aten_inplace is None: + continue m = TestHelperModules.LinearUnaryModule( - use_bias=use_bias, postop=postop, inplace_postop=inplace + use_bias=use_bias, + postop=post_op_module, + inplace_postop=inplace, + post_op_algo=post_op_algo, ).eval() example_inputs = (torch.randn(2, 4),) quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) ) - node_occurrence = { - # one for input and weight of the conv, one for output for the relu - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, - } - node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - post_op_map[postop][0 if inplace else 1], - ] - self._test_quantizer( - m, - example_inputs, - quantizer, - node_occurrence, - node_list, + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default ) - - @skipIfNoX86 - def test_linear_unary_gelu(self): - """ - Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer. 
- """ - use_bias_list = [True, False] - postop = nn.GELU - post_op_algorithm = ["none", "tanh"] - cases = itertools.product(use_bias_list, post_op_algorithm) - with override_quantized_engine("x86"), torch.no_grad(): - for use_bias, post_op_algo in cases: - m = TestHelperModules.LinearUnaryModule( - use_bias=use_bias, postop=postop, post_op_algo=post_op_algo - ).eval() - example_inputs = (torch.randn(2, 4),) - quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) node_occurrence = { - # one for input and weight of the conv, one for output for the gelu - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # one for input of the linear + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, - torch.ops.aten.gelu.default, + post_op_aten_inplace if inplace else post_op_aten, ] self._test_quantizer( m, @@ -1283,8 +1262,71 @@ def test_linear_unary_gelu(self): quantizer, node_occurrence, node_list, + is_qat=is_qat, ) + @skipIfNoX86 + def test_linear_unary(self): + aten = torch.ops.aten + self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"] + ) + + @skipIfNoX86 + def test_linear_unary_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True + ) + + @skipIfNoX86 + def test_linear_unary_dynamic(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True + ) + + @skipIfNoX86 + def test_linear_unary_dynamic_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_qat=True, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, + aten.gelu.default, + None, + ["none", "tanh"], + is_qat=True, + is_dynamic=True, + ) + def _check_annotation_stat(self, gm, expected_stat_dict): # Check expected annotation statistics to ensure the annotation is correct @@ -1302,8 +1344,7 @@ def _check_annotation(node): for op_stat in expected_stat_dict.values(): assert all(v == 0 for v in 
op_stat.values()) - @skipIfNoX86 - def test_linear_binary(self): + def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): """ Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer. Currently, only add as binary post op is supported. @@ -1313,7 +1354,20 @@ def test_linear_binary(self): inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) cases = itertools.product(linear_pos_list, inplace_add_list) with override_quantized_engine("x86"), torch.no_grad(): @@ -1325,26 +1379,28 @@ def test_linear_binary(self): node_occurrence = { # Only one 1 q-dq for input of the linear # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 node_occurrence = { # One quantize_per_tensor for both linear nodes (shared) # Two dequantize_per_tensor for two linear nodes # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, torch.ops.aten.add_.Tensor if inplace_add @@ -1356,6 +1412,7 @@ def test_linear_binary(self): quantizer, node_occurrence, node_list, + is_qat=is_qat, )[-1] # One linear and add are fused. 
The other linear is quantized alone if present aten = torch.ops.aten @@ -1369,6 +1426,22 @@ def test_linear_binary(self): } self._check_annotation_stat(fq_m, expected_annotation_stat) + @skipIfNoX86 + def test_linear_binary(self): + self._test_linear_binary_helper() + + @skipIfNoX86 + def test_linear_binary_qat(self): + self._test_linear_binary_helper(is_qat=True) + + @skipIfNoX86 + def test_linear_binary_dynamic(self): + self._test_linear_binary_helper(is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_dynamic_qat(self): + self._test_linear_binary_helper(is_qat=True, is_dynamic=True) + @skipIfNoX86 def test_linear_binary2(self): """ @@ -1379,28 +1452,43 @@ def test_linear_binary2(self): Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) - quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() - ) # TODO test for inplace add after refactoring of capture_pre_autograd_graph inplace_add_list = [False] + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list) with override_quantized_engine("x86"), torch.no_grad(): - for inplace_add in inplace_add_list: + for inplace_add, is_qat, is_dynamic in cases: + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, is_dynamic=is_dynamic + ) + ) m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval() + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) # Two q-dq nodes for inputs of linear nodes # No q-dq for extra input node of add node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + quantize_per_tensor_op: 2, + dequantize_per_tensor_op: 2, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor, @@ -1425,7 +1513,7 @@ def test_linear_binary2(self): self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfNoX86 - def test_linear_binary_unary(self): + def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): """ Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer. Currently, only add as binary post op and relu as unary post op are supported. 
@@ -1437,7 +1525,20 @@ def test_linear_binary_unary(self): inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list) with override_quantized_engine("x86"), torch.no_grad(): @@ -1451,26 +1552,28 @@ def test_linear_binary_unary(self): node_occurrence = { # Only one q-dq node for input of the linear # No q-dq node for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 node_occurrence = { # One quantize_per_tensor for both linear nodes (shared) # Two dequantize_per_tensor for two linear nodes # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, torch.ops.aten.linear.default, torch.ops.aten.add_.Tensor if inplace_add @@ -1498,57 +1601,91 @@ def test_linear_binary_unary(self): } self._check_annotation_stat(fq_m, expected_annotation_stat) + @skipIfNoX86 + def test_linear_binary_unary(self): + self._test_linear_binary_unary_helper() + + @skipIfNoX86 + def test_linear_binary_unary_qat(self): + self._test_linear_binary_unary_helper(is_qat=True) + + @skipIfNoX86 + def test_linear_binary_unary_dynamic(self): + self._test_linear_binary_unary_helper(is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_unary_dynamic_qat(self): + self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True) + @skipIfNoX86 def test_linear_binary_unary_serials(self): """ Test pattern of 2 following up linear add relu with X86InductorQuantizer. 
""" + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(is_qat_list, is_dynamic_list) with override_quantized_engine("x86"), torch.no_grad(): - m = TestHelperModules.SerialsLinearAddReLUModule().eval() - example_inputs = (torch.randn(2, 16),) - quantizer = X86InductorQuantizer().set_global( - xiq.get_default_x86_inductor_quantization_config() - ) - node_occurrence = { - # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 - # dequantize_per_tensor: 1 for each linear - # No q-dq for extra input node of add - torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - # quantize_per_channel for weights are const propagated - torch.ops.quantized_decomposed.quantize_per_channel.default: 0, - torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, - } - node_list = [ - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.aten.linear.default, - torch.ops.aten.linear.default, - torch.ops.aten.add.Tensor, - torch.ops.aten.relu.default, - ] - fq_m = self._test_quantizer( - m, - example_inputs, - quantizer, - node_occurrence, - node_list, - )[-1] - # Two linear nodes are quantized alone - # The other two are fused with add and relu - aten = torch.ops.aten - expected_annotation_stat = { - aten.linear.default: { - "annotated": 4, - "is_quant_out": 2, - }, - aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, - aten.relu.default: {"annotated": 2, "is_quant_out": 2}, - } - self._check_annotation_stat(fq_m, expected_annotation_stat) + for is_qat, is_dynamic in cases: + m = TestHelperModules.SerialsLinearAddReLUModule().eval() + example_inputs = (torch.randn(2, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 3 if is_dynamic else 4 + node_occurrence = { + # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 + # dequantize_per_tensor: 1 for each linear + # No q-dq for extra input node of add + quantize_per_tensor_op: 3, + dequantize_per_tensor_op: num_dequant, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.relu.default, + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + )[-1] + # Two linear nodes are quantized alone + # The other two are fused with add and relu + aten = torch.ops.aten + 
expected_annotation_stat = { + aten.linear.default: { + "annotated": 4, + "is_quant_out": 2, + }, + aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, + aten.relu.default: {"annotated": 2, "is_quant_out": 2}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfTorchDynamo("very slow") @skipIfNoX86 diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 4cc05e46c6a70..ecb9a14c0a4c6 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -776,10 +776,8 @@ def _annotate_conv2d_fusion_pattern(self, model: torch.fx.GraphModule): def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule): if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default): - if config.input_activation and not config.input_activation.is_dynamic: - # Weiwen: Dynamic Quant of linear unary will be supported in next step - self._annotate_linear_binary_unary(model, config) - self._annotate_linear_unary(model, config) + self._annotate_linear_binary_unary(model, config) + self._annotate_linear_unary(model, config) self._annotate_linear(model, config) def _annotate_matmul(self, model: torch.fx.GraphModule): From 7aa853a54e0731601b3f2191e9f45bd0f3374eef Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 20 May 2024 16:39:14 +0000 Subject: [PATCH 06/35] [CI] Install sccache on XLA build job (#126117) XLA build job uses a docker image from XLA, which doesn't have sccache installed. The XLA build job just builds pytorch, XLA gets built during the test job. The pytorch build was taking 1+hrs, with a warm cache it takes <30min Pull Request resolved: https://github.com/pytorch/pytorch/pull/126117 Approved by: https://github.com/malfet --- .ci/pytorch/build.sh | 3 +++ .ci/pytorch/install_cache_xla.sh | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100755 .ci/pytorch/install_cache_xla.sh diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 4aa5dc39d0f5f..46f91f71283ff 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -289,6 +289,9 @@ else fi WERROR=1 python setup.py bdist_wheel else + if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then + source .ci/pytorch/install_cache_xla.sh + fi python setup.py bdist_wheel fi pip_install_whl "$(echo dist/*.whl)" diff --git a/.ci/pytorch/install_cache_xla.sh b/.ci/pytorch/install_cache_xla.sh new file mode 100755 index 0000000000000..bfc2da177f6ed --- /dev/null +++ b/.ci/pytorch/install_cache_xla.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Script for installing sccache on the xla build job, which uses xla's docker +# image and doesn't have sccache installed on it. This is mostly copied from +# .ci/docker/install_cache.sh. Changes are: removing checks that will always +# return the same thing, ex checks for for rocm, CUDA, and changing the path +# where sccache is installed, and not changing /etc/environment. 
+ +set -ex + +install_binary() { + echo "Downloading sccache binary from S3 repo" + curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /tmp/cache/bin/sccache +} + +mkdir -p /tmp/cache/bin +mkdir -p /tmp/cache/lib +export PATH="/tmp/cache/bin:$PATH" + +install_binary +chmod a+x /tmp/cache/bin/sccache + +function write_sccache_stub() { + # Unset LD_PRELOAD for ps because of asan + ps issues + # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90589 + # shellcheck disable=SC2086 + # shellcheck disable=SC2059 + printf "#!/bin/sh\nif [ \$(env -u LD_PRELOAD ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/tmp/cache/bin/$1" + chmod a+x "/tmp/cache/bin/$1" +} + +write_sccache_stub cc +write_sccache_stub c++ +write_sccache_stub gcc +write_sccache_stub g++ +write_sccache_stub clang +write_sccache_stub clang++ From 8c38d0cd648a8cef9518591f3b4dc257104a5fa8 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Sat, 18 May 2024 23:22:03 -0700 Subject: [PATCH 07/35] [inductor] Fix edge case in JIT vs. AOT fusion after finalizing MultiTemplateBuffer (#126622) # Context Here's a peripheral scenario causing the JIT-pass and AOT-pass to pick different fusions. ```py # JIT -- buf3 is a MultiTemplateBuffer V.graph.buffers = [buf0, buf1, buf2, buf3, buf4] ^ ^ # JIT pass calls finalize_multi_template_buffers() V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*] # AOT, note proximity_score(buf2, buf4) is "better" for fusion than JIT V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*] ^ ^ ``` It happens like this: * JIT starts with the original set nodes using V.graph.buffers * In JIT, finalize_multi_template_buffers() is called which can change the order of the buffers. * This makes the order of buffers/scheduler nodes different. * Now, each node's min/max-order is different than before. * As a result, the proximity between two nodes is different. https://github.com/pytorch/pytorch/blob/ad67553c5c1672d65b810acd7a6a01e11695098b/torch/_inductor/scheduler.py#L2316-L2335 # Error ``` $ TORCH_LOGS="+fusion" python test/inductor/test_max_autotune.py -k test_jit_fusion_matches_aot_fusion ====================================================================== FAIL: test_jit_fusion_matches_aot_fusion (__main__.TestMaxAutotune) ---------------------------------------------------------------------- Traceback (most recent call last): ... 
File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1718, in compile_to_fn code, linemap = self.codegen_with_cpp_wrapper() File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1618, in codegen_with_cpp_wrapper return self.codegen() File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1636, in codegen self.scheduler.codegen() File "/data/users/colinpeppler/pytorch/torch/_dynamo/utils.py", line 210, in time_wrapper r = func(*args, **kwargs) File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2602, in codegen self.get_backend(device).codegen_node(node) # type: ignore[possibly-undefined] File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 66, in codegen_node return self._triton_scheduling.codegen_node(node) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3377, in codegen_node return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3602, in codegen_node_schedule final_kernel.call_kernel(final_kernel.kernel_name) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3055, in call_kernel grid = wrapper.generate_default_grid(name, grid) File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cpp_wrapper_cuda.py", line 174, in generate_default_grid params is not None AssertionError: cuda kernel parameters for triton_poi_fused_add_0 should already exist at this moment, only found dict_keys(['Placeholder.DESCRIPTIVE_NAME', 'triton_poi_fused_add_mul_0', 'triton_poi_fused_pow_1']) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126622 Approved by: https://github.com/chenyang78 ghstack dependencies: #125982 --- test/inductor/test_max_autotune.py | 21 +++++++++++++++++++++ torch/_inductor/scheduler.py | 4 +++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index c5f0afa118f87..1859ca391e02a 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -463,6 +463,27 @@ def fn(a, b, c): self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + @skipIfRocm + @fresh_inductor_cache() + @config.patch(max_autotune=True, max_fusion_size=2) + def test_jit_fusion_matches_aot_fusion(self): + # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due + # to proximity, we want to make sure AOT-compile pass does the same. + # AOT could do fuse(buf2, buf4) instead if buf3 was pushed to the end + # of the V.graph.buffers list because fuse(buf2, buf4) would have a + # better proximity score than fuse(buf1, buf2). This scenario is possible + # since finalizing MultiTemplateBuffers needs to replace buffers. 
+ def fn(x, number): + buf0 = x + x + buf1 = number.item() + buf2 = x * x + buf3 = x @ x # MultiTemplateBuffer + buf4 = x**2 + return buf0, buf1, buf2, buf3, buf4 + + inputs = (torch.rand([256, 256], device="cuda"), torch.tensor(3, device="cuda")) + torch._export.aot_compile(fn, args=inputs) + @config.patch(autotune_local_cache=False, autotune_remote_cache=False) def test_precompilations(self): def fn(a, b, c): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 456e0c50567d5..ec4763160a7b6 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1752,7 +1752,9 @@ def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer): del V.graph.name_to_buffer[replaced_name] new_node.name = orig_name - V.graph.buffers.remove(orig_node) + orig = V.graph.buffers.index(orig_node) + V.graph.buffers.remove(new_node) + V.graph.buffers[orig] = new_node V.graph.name_to_buffer[orig_name] = new_node for i, node in enumerate(self.nodes): From 655038687afd19a4a4c9371b77ff046fd6c84be1 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 20 May 2024 17:36:30 +0000 Subject: [PATCH 08/35] [TD] Upload names of failures to s3 for pytest cache (#126315) Some tests don't get run through pytest and pytest crashes when a test segfaults, so in both caess, the pytest cache won't have an entry (similar to https://github.com/pytorch/test-infra/pull/5205). Instead, manually upload/download an extra file that lists the failing test files Technically this would be more general than the pytest cache Pull Request resolved: https://github.com/pytorch/pytorch/pull/126315 Approved by: https://github.com/ZainRizvi --- .github/scripts/pytest_caching_utils.py | 29 +++++++++++++++++++ test/run_test.py | 6 ++++ tools/stats/import_test_stats.py | 15 ++++++++++ .../testing/do_target_determination_for_s3.py | 2 ++ .../heuristics/previously_failed_in_pr.py | 29 ++++++++++++++++++- 5 files changed, 80 insertions(+), 1 deletion(-) diff --git a/.github/scripts/pytest_caching_utils.py b/.github/scripts/pytest_caching_utils.py index b2a71fb4b8e13..e1a581c2d3095 100644 --- a/.github/scripts/pytest_caching_utils.py +++ b/.github/scripts/pytest_caching_utils.py @@ -18,6 +18,7 @@ PYTEST_CACHE_DIR_NAME = ".pytest_cache" BUCKET = "gha-artifacts" LASTFAILED_FILE_PATH = Path("v/cache/lastfailed") +TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL = "previous_failures_additional.json" # Temp folders ZIP_UPLOAD = "zip-upload" @@ -191,6 +192,10 @@ def _merge_pytest_caches( pytest_cache_dir_to_merge_from, pytest_cache_dir_to_merge_into ) + _merge_additional_failures_files( + pytest_cache_dir_to_merge_from, pytest_cache_dir_to_merge_into + ) + def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path) -> None: # Simple cases where one of the files doesn't exist @@ -232,3 +237,27 @@ def _merged_lastfailed_content( del to_lastfailed[""] return to_lastfailed + + +def _merge_additional_failures_files( + source_pytest_cache: Path, dest_pytest_cache: Path +) -> None: + # Simple cases where one of the files doesn't exist + source_lastfailed_file = ( + source_pytest_cache / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL + ) + dest_lastfailed_file = dest_pytest_cache / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL + + if not source_lastfailed_file.exists(): + return + if not dest_lastfailed_file.exists(): + copy_file(source_lastfailed_file, dest_lastfailed_file) + return + + # Both files exist, so we need to merge them + from_lastfailed = load_json_file(source_lastfailed_file) + to_lastfailed = 
load_json_file(dest_lastfailed_file) + merged_content = list(set(from_lastfailed + to_lastfailed)) + + # Save the results + write_json_file(dest_lastfailed_file, merged_content) diff --git a/test/run_test.py b/test/run_test.py index 71ab08199f7a7..d43a396f1441e 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -59,6 +59,9 @@ ) from tools.testing.do_target_determination_for_s3 import import_results from tools.testing.target_determination.gen_artifact import gen_ci_artifact +from tools.testing.target_determination.heuristics.previously_failed_in_pr import ( + gen_additional_test_failures_file, +) from tools.testing.target_determination.heuristics.utils import get_pr_number from tools.testing.test_run import TestRun @@ -1795,6 +1798,9 @@ def __str__(self): **test_stats, }, ) + gen_additional_test_failures_file( + [test.test_file for test, _ in all_failures] + ) if len(all_failures): for _, err in all_failures: diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index a256c3ea04c34..513edb12fcfe6 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -28,6 +28,7 @@ def get_disabled_issues() -> List[str]: TD_HEURISTIC_PROFILING_FILE = "td_heuristic_profiling.json" TD_HEURISTIC_HISTORICAL_EDITED_FILES = "td_heuristic_historical_edited_files.json" TD_HEURISTIC_PREVIOUSLY_FAILED = "previous_failures.json" +TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL = "previous_failures_additional.json" FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds @@ -165,6 +166,20 @@ def copy_pytest_cache() -> None: ) +def copy_additional_previous_failures() -> None: + original_path = ( + REPO_ROOT / ".pytest_cache" / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL + ) + if not original_path.exists(): + return + shutil.copyfile( + original_path, + REPO_ROOT + / ADDITIONAL_CI_FILES_FOLDER + / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL, + ) + + def get_from_test_infra_generated_stats( from_file: str, to_file: str, failure_explanation: str ) -> Dict[str, Any]: diff --git a/tools/testing/do_target_determination_for_s3.py b/tools/testing/do_target_determination_for_s3.py index 4b004801fc747..32ea85b980214 100644 --- a/tools/testing/do_target_determination_for_s3.py +++ b/tools/testing/do_target_determination_for_s3.py @@ -8,6 +8,7 @@ sys.path.insert(0, str(REPO_ROOT)) from tools.stats.import_test_stats import ( + copy_additional_previous_failures, copy_pytest_cache, get_td_heuristic_historial_edited_files_json, get_td_heuristic_profiling_json, @@ -51,6 +52,7 @@ def main() -> None: get_td_heuristic_historial_edited_files_json() get_td_heuristic_profiling_json() copy_pytest_cache() + copy_additional_previous_failures() aggregated_heuristics = get_test_prioritizations(selected_tests) diff --git a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py index 8d15d7537e915..81e808d063c3a 100644 --- a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py +++ b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py @@ -6,6 +6,7 @@ from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, TD_HEURISTIC_PREVIOUSLY_FAILED, + TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL, ) from tools.testing.target_determination.heuristics.interface import ( @@ -25,7 +26,7 @@ def __init__(self, **kwargs: Dict[str, Any]): super().__init__(**kwargs) def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: - critical_tests = 
get_previous_failures() + critical_tests = get_previous_failures() | read_additional_test_failures_file() return TestPrioritizations( tests, {TestRun(test): 1 for test in critical_tests if test in tests} ) @@ -54,3 +55,29 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st prioritized_tests.add(test_file) return prioritized_tests + + +def gen_additional_test_failures_file(tests: List[str]) -> None: + # Segfaults usually result in no xml and some tests don't run through pytest + # (ex doctests). In these cases, there will be no entry in the pytest + # cache, so we should generate a separate file for them and upload it to s3 + # along with the pytest cache + with open( + REPO_ROOT / ".pytest_cache" / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL, "w" + ) as f: + json.dump(tests, f, indent=2) + + +def read_additional_test_failures_file() -> Set[str]: + path = ( + REPO_ROOT + / ADDITIONAL_CI_FILES_FOLDER + / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL + ) + if not os.path.exists(path): + print(f"could not find path {path}") + return set() + with open(path) as f: + s = set(json.load(f)) + print(f"additional failures: {s}") + return s From 89c1cfe144f2ee9ddeaa2a3156cea69730b130f7 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 20 May 2024 17:42:16 +0000 Subject: [PATCH 09/35] [export] Allow modules to be created in the forward (#125725) Fixes the error in non-strict export when we're tracing a module that initializes another module in its forward function. This appears in [many huggingface models](https://github.com/search?q=repo%3Ahuggingface%2Ftransformers+CrossEntropyLoss%28%29&type=code&fbclid=IwAR285uKvSevJM6SDbXmb4-monj4iH7wf8opkvnec-li7sKpn4lUMjIvbGKc). It's probably not good practice to do this, but since it appears in so many places, and strict-export supports this, we will also support this. The approach we'll take for these cases is that we will inline the call to the module. Parameters and buffers initialized as constants (with `torch.tensor`) will be represented as constant tensors, and those initialized with tensor factory functions (`torch.ones`) will show up as an operator in the graph. The module stack for the ops in the inlined module will reflect the toplevel's module stack. One issue is that strict-export seems to segfault when there is an `nn.Parameter` call in the constructor (https://github.com/pytorch/pytorch/issues/126109). Non-strict export will succeed. 
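A minimal sketch of the pattern this enables, condensed from the new tests below (class names changed for readability; the buffer is initialized with a tensor factory so it surfaces as an op rather than a lifted constant):

```py
import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b = torch.ones(3, 3)  # factory-initialized buffer -> traced as an aten.ones op

    def forward(self, x):
        return x + self.b

class Outer(torch.nn.Module):
    def forward(self, x):
        # Inner is constructed inside forward, so its call is inlined; the
        # nn_module_stack of the inlined ops reflects Outer.
        return Inner()(x) * x

ep = torch.export.export(Outer(), (torch.randn(3, 3),), strict=False)
print(ep.graph)  # aten.ones / add / mul; nothing lifted into state_dict or constants
```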
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125725 Approved by: https://github.com/ydwu4 --- test/export/test_export.py | 154 ++++++++++++++++++++++++++ torch/fx/experimental/proxy_tensor.py | 22 +++- 2 files changed, 174 insertions(+), 2 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 406e1f55dd804..6de95a1b6b406 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -148,6 +148,7 @@ class Inp: NON_STRICT_SUFFIX = "_non_strict" RETRACEABILITY_SUFFIX = "_retraceability" +PREDISPATCH_SUFFIX = "_pre_dispatch" def is_non_strict_test(test_name): @@ -3279,6 +3280,159 @@ def dynamify_inp(x): with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"): ep.module()(*test_inp) + def test_nested_module(self): + class M1(torch.nn.Module): + def forward(self, x): + return x + x + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + ep = export(M2(), inps) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + add_nodes = [ + node + for node in ep.graph.nodes + if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor + ] + self.assertEqual(len(add_nodes), 1) + add_node = add_nodes[0] + self.assertEqual(len(add_node.meta["nn_module_stack"]), 1) + self.assertTrue("M2" in list(add_node.meta["nn_module_stack"].values())[0][1]) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %x : [num_users=2] = placeholder[target=x] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + + def test_nested_module_with_init_buffer(self): + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.ones(3, 3) + + def forward(self, x): + return x + self.b + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + ep = export(M2(), inps) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + self.assertEqual(len(ep.state_dict), 0) + self.assertEqual(len(ep.constants), 0) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %x : [num_users=2] = placeholder[target=x] + %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %ones), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + + @testing.expectedFailureRetraceability # Retracing tensor constants results in buffers + def test_nested_module_with_constant_buffer(self): + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.tensor(5) + + def forward(self, x): + return x + self.b + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + ep = export(M2(), inps) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + self.assertEqual(len(ep.state_dict), 0) + self.assertEqual(len(ep.constants), 1) + + 
self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + %x : [num_users=2] = placeholder[target=x] + %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) + %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + + def test_nested_module_with_parameter(self): + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.ones(3, 3)) + self.b = torch.nn.Parameter(torch.tensor(5.0)) + + def forward(self, x): + return x + self.a * self.b + + class M2(torch.nn.Module): + def forward(self, x): + m = M1() + return m(x) * x + + inps = (torch.randn(3, 3),) + # Strict export segfaults (Issue #128109) + ep = torch.export.export(M2(), inps, strict=False) + self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) + + self.assertEqual(len(ep.state_dict), 0) + self.assertEqual(len(ep.constants), 1) + + self.assertExpectedInline( + str(ep.graph).strip(), + """\ +graph(): + %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + %x : [num_users=2] = placeholder[target=x] + %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) + %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {}) + %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) + %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) + %detach_2 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {}) + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) + return (mul_1,)""", + ) + + unflattened = unflatten(ep) + self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + def test_lazy_module_kwargs(self): class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): def initialize_parameters(self, *args, **kwargs): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 9976c4e9beca2..511762c1612cc 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -29,6 +29,7 @@ from torch.utils._traceback import CapturedTraceback import logging from torch._library.fake_class_registry import FakeScriptObject +import warnings from torch.overrides import TorchFunctionMode @@ -921,6 +922,10 @@ def disable_autocast_cache(): torch.set_autocast_cache_enabled(old_value) +class ModuleNotInstalledAsSubmoduleError(NameError): + pass + + class _ModuleStackTracer(PythonKeyTracer): r"""Customized version of PythonKeyTracer that retains 
module stack information in node.meta["nn_module_stack"]. @@ -998,7 +1003,10 @@ def path_of_module(self, mod: torch.nn.Module) -> str: if isinstance(mod, self.proxy_type): return self.proxy_paths[mod] - return Tracer.path_of_module(self, mod) + try: + return Tracer.path_of_module(self, mod) + except NameError as e: + raise ModuleNotInstalledAsSubmoduleError from e def getattr(self, attr, attr_val, parameter_proxy_cache): if not isinstance(attr_val, torch.nn.Module) or isinstance(attr_val, torch.fx.GraphModule): @@ -1070,7 +1078,17 @@ def call_module(self, m, forward, args, kwargs): # use cases don't need to work with HOO. if isinstance(m, (OptimizedModule, GraphModule)): return forward(*args, **kwargs) - return Tracer.call_module(self, m, forward, args, kwargs) + + try: + return Tracer.call_module(self, m, forward, args, kwargs) + except ModuleNotInstalledAsSubmoduleError as e: + warnings.warn( + f"Unable to find the path of the module {m}. " + "This might be because the module was not properly registered " + "as a submodule, which is not good practice. We will trace " + "through the module without recording stack information." + ) + return forward(*args, **kwargs) def is_leaf_module(self, m, module_qualified_name): From f9de510121be4e9e0f8fd95ed99dd5d877ac8aed Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 18 May 2024 17:00:46 -0700 Subject: [PATCH 10/35] [dynamo] Graph break on set_num_threads (#126623) Fixes #125364 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126623 Approved by: https://github.com/yanboliang --- .../TestNNDeviceTypeCPU.test_conv_empty_input_cpu_complex128} | 0 torch/_dynamo/trace_rules.py | 3 --- 2 files changed, 3 deletions(-) rename test/{dynamo_expected_failures/TestSortAndSelectCPU.test_sort_overflow_cpu_int16 => dynamo_skips/TestNNDeviceTypeCPU.test_conv_empty_input_cpu_complex128} (100%) diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_sort_overflow_cpu_int16 b/test/dynamo_skips/TestNNDeviceTypeCPU.test_conv_empty_input_cpu_complex128 similarity index 100% rename from test/dynamo_expected_failures/TestSortAndSelectCPU.test_sort_overflow_cpu_int16 rename to test/dynamo_skips/TestNNDeviceTypeCPU.test_conv_empty_input_cpu_complex128 diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 8a2c12ee4e84e..9ac3db8647089 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1302,9 +1302,6 @@ "torch._C.parse_schema", "torch._C.parse_type_comment", "torch._C.read_vitals", - "torch._C.set_flush_denormal", - "torch._C.set_num_interop_threads", - "torch._C.set_num_threads", "torch._C.set_vital", "torch._C.unify_type_list", "torch._C.vitals_enabled", From 022adf8c5ef616173d9bd912f0f824093d86cc9c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 19 May 2024 05:24:08 -0700 Subject: [PATCH 11/35] Fix bug for comptime.get_local for cells/closures (#126637) I wasn't paying enough attention and didn't notice that LOAD_DEREF is defined differently for InliningInstructionTranslator. Match it up with the code there. This also fixes comptime.print(), which was broken, because closing over an argument turned it into a cell rather than a regular local. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126637 Approved by: https://github.com/yanboliang --- test/dynamo/test_comptime.py | 37 +++++++++++++++++++++++++++++++++ torch/_dynamo/comptime.py | 17 ++++++++++++++- torch/_dynamo/variables/misc.py | 3 ++- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 6d874a005047b..a14c889a3bce7 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -189,6 +189,43 @@ def _(ctx): """, ) + # Just make sure it doesn't crash + def test_print_direct(self): + cnt = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.optimize(cnt) + def f(x, z): + y = x * 2 + lambda: z + comptime.print(z) + return y + 3 + + f(torch.randn(2), torch.randn(2)) + + # Just make sure it doesn't crash + def test_get_local_closure_variable(self): + global SELF + SELF = self + cnt = torch._dynamo.testing.CompileCounter() + + @torch._dynamo.optimize(cnt) + def f(x): + z = 3 + + def g(): + @comptime + def _(ctx): + r = ctx.get_local("z") + SELF.assertEqual(repr(r), "3") + + comptime.print(z) + return 2 + + y = x * g() + return y + 3 + + f(torch.randn(2)) + def test_print_bt(self): global FILE FILE = StringIO() diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 23000c464fdbb..80880588b54e3 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -14,7 +14,9 @@ from torch.fx.experimental.symbolic_shapes import free_symbols from .exc import unimplemented +from .variables import NewCellVariable from .variables.constant import ConstantVariable +from .variables.misc import ClosureVariable from .variables.tensor import SymNodeVariable @@ -146,7 +148,20 @@ def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: Retrieve the compile-time known information about a local. """ tx = self.__get_tx(stacklevel) - return ComptimeVar(tx.symbolic_locals[name]) + + # This is analogous to LOAD_DEREF + if hasattr(tx, "closure_cells") and name in tx.closure_cells: + cell = tx.closure_cells[name] + if isinstance(cell, ClosureVariable): + return ComptimeVar(tx.output.root_tx.symbolic_locals[cell.name]) + else: + return ComptimeVar(tx.output.side_effects.load_cell(cell)) + else: + r = tx.symbolic_locals[name] + if isinstance(r, NewCellVariable): + return ComptimeVar(tx.output.side_effects.load_cell(r)) + else: + return ComptimeVar(r) def graph_break(self, msg="ComptimeContext.graph_break"): """ diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 3e9495b3c7ca8..02f4f8f47b279 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -218,7 +218,8 @@ def call_function( # TODO: support an expression form as well assert not kwargs - assert len(args) == 1 + # Second argument is runtime lambda, ignored + assert len(args) <= 2 fn = args[0] if isinstance(fn, UserFunctionVariable): fn.get_function()(ComptimeContext(tx)) From 2068dadbe817da521695a86118cb80776dbd6f46 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 20 May 2024 17:53:44 +0000 Subject: [PATCH 12/35] [torchbench] Add torchao to PT2 Benchmark Runner (#126469) Summary: X-link: https://github.com/pytorch/benchmark/pull/2268 Support torchao performance and accuracy tests in PT2 Benchmark Runner, using the inductor backend as the baseline. 
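Under the hood, `--quantization <mode>` selects a new `torchao` backend: the runner compiles the baseline `model_iter_fn` with inductor and wraps it in a context that applies the requested torchao transform to the module exactly once before measuring. A simplified sketch of the `int8dynamic` case, condensed from the `benchmarks/dynamo/torchao.py` helper added below (other modes, autoquant handling, and error checks omitted; `int8dynamic_ctx` is an illustrative name, not the runner's API):

```py
from torchao.quantization import change_linear_weights_to_int8_dqtensors

def int8dynamic_ctx(model_iter_fn):
    def run(module, example_inputs):
        if not getattr(module, "_quantized", False):
            # Swap nn.Linear weights for int8 dynamically-quantized tensors in place.
            change_linear_weights_to_int8_dqtensors(module)
            module._quantized = True
        return model_iter_fn(module, example_inputs)
    return run
```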
Test Plan: ``` $ buck2 run mode/opt //caffe2/benchmarks/dynamo:torchbench -- --only BERT_pytorch --bfloat16 --quantization int8dynamic --performance --inference --print-memory loading model: 0it [00:50, ?it/s] cuda eval BERT_pytorch memory: eager: 0.75 GB, dynamo: 0.75 GB, ratio: 1.00 running benchmark: 100% 1.003x ``` Reviewed By: jerryzh168 Differential Revision: D57463273 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126469 Approved by: https://github.com/huydhn --- benchmarks/dynamo/common.py | 29 +++++++++++++++++++ benchmarks/dynamo/torchao.py | 54 ++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 benchmarks/dynamo/torchao.py diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 6ea7a31a39150..f40f40396992e 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3485,6 +3485,18 @@ def get_example_inputs(self): action="store_true", help="Measure speedup with TorchInductor", ) + group.add_argument( + "--quantization", + choices=[ + "int8dynamic", + "int8weightonly", + "int4weightonly", + "autoquant", + "noquant", + ], + default=None, + help="Measure speedup of torchao quantization with TorchInductor baseline", + ) group.add_argument( "--export", action="store_true", @@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None): if args.inductor: assert args.backend is None args.backend = "inductor" + if args.quantization: + assert args.backend is None + args.backend = "torchao" if args.dynamic_batch_only: args.dynamic_shapes = True torch._dynamo.config.assume_static_by_default = True @@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None): # AOTInductor doesn't support control flow yet runner.skip_models.update(runner.skip_models_due_to_control_flow) + elif args.backend == "torchao": + assert "cuda" in args.devices, "Quantization requires CUDA device." + assert args.bfloat16, "Quantization requires dtype bfloat16." 
+ from .torchao import setup_baseline, torchao_optimize_ctx + + setup_baseline() + baseline_ctx = functools.partial( + torch.compile, + backend="inductor", + fullgraph=args.nopython, + mode=args.inductor_compile_mode, + ) + runner.model_iter_fn = baseline_ctx(runner.model_iter_fn) + optimize_ctx = torchao_optimize_ctx(args.quantization) else: optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) experiment = speedup_experiment diff --git a/benchmarks/dynamo/torchao.py b/benchmarks/dynamo/torchao.py new file mode 100644 index 0000000000000..29e7d55d76ce1 --- /dev/null +++ b/benchmarks/dynamo/torchao.py @@ -0,0 +1,54 @@ +from typing import Any, Callable + +import torch + + +def setup_baseline(): + torch._dynamo.epilogue_fusion = False + torch._dynamo.config.automatic_dynamic_shapes = False + torch._dynamo.config.force_parameter_static_shapes = False + torch._dynamo.config.cache_size_limit = 10000 + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._inductor.config.use_mixed_mm = True + + +def torchao_optimize_ctx(quantization: str): + import torchao + from torchao.quantization import ( + change_linear_weights_to_int4_woqtensors, + change_linear_weights_to_int8_dqtensors, + change_linear_weights_to_int8_woqtensors, + ) + + def inner(model_iter_fn: Callable): + def _torchao_apply(module: torch.nn.Module, example_inputs: Any): + if getattr(module, "_quantized", None) is None: + if quantization == "int8dynamic": + change_linear_weights_to_int8_dqtensors(module) + elif quantization == "int8weightonly": + change_linear_weights_to_int8_woqtensors(module) + elif quantization == "int4weightonly": + change_linear_weights_to_int4_woqtensors(module) + elif quantization == "autoquant": + torchao.autoquant(module, error_on_unseen=False) + if isinstance(example_inputs, dict): + module(**example_inputs) + else: + module(*example_inputs) + from torchao.quantization.autoquant import AUTOQUANT_CACHE + + assert ( + len(AUTOQUANT_CACHE) > 0 + ), f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization" + elif quantization == "noquant": + pass + else: + raise AssertionError( + f"Unsupposed quantization mode {quantization}." + ) + setattr(module, "_quantized", True) # noqa: B010 + model_iter_fn(module, example_inputs) + + return _torchao_apply + + return inner From 11c2d127ec8ff7dc6ca50edf778528bd8acf9452 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Mon, 20 May 2024 18:16:00 +0000 Subject: [PATCH 13/35] [AOTInductor] Add config to allow buffer mutation (#126584) Summary: Add an additional config to allow buffer mutation. For data that's greater than 2GB, we would need to set it as read-only, otherwise overflow would occur. This is a temporary solution since it won't handle cases that requires mutable data greater than 2GB. Test Plan: Included in commit. 
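A minimal sketch of opting in, mirroring the updated AOTInductor tests (`model` and `example_inputs` are placeholders for an `nn.Module` and its sample inputs):

```py
import torch
from torch._inductor import config

with config.patch({"aot_inductor.allow_buffer_mutation": True}):
    # Buffers mutated in forward stay writable (.ldata); without the flag the
    # constants blob is linked read-only (.lrodata) so larger weights can fit.
    so_path = torch._export.aot_compile(model, example_inputs)
```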
Differential Revision: D57514729 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126584 Approved by: https://github.com/chenyang78 --- test/inductor/test_aot_inductor.py | 31 ++++++++++++++++++------------ torch/_inductor/codecache.py | 15 +++++++++++---- torch/_inductor/config.py | 3 +++ 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 49afa50b78ac2..f1b153592db4b 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -257,20 +257,24 @@ def forward(self, x): example_inputs = (torch.randn(32, 64, device=self.device),) self.check_model(Model(), example_inputs) - def test_large(self): + def test_large_weight(self): class Model(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(512, 250112) + self.linear = torch.nn.Linear(2048, 262144) def forward(self, x, y): return x + self.linear(y) example_inputs = ( - torch.randn(1, 250112, device=self.device), - torch.randn(1, 512, device=self.device), + torch.randn(1, 262144, device=self.device), + torch.randn(1, 2048, device=self.device), ) - self.check_model(Model(), example_inputs) + + # We only test compilation since we often get OOM running in CI. + model = Model() + model = model.to(self.device) + AOTIRunnerUtil.compile(model, example_inputs) def test_large_mmaped_weights(self): class Model(torch.nn.Module): @@ -1208,8 +1212,9 @@ def forward(self, x): return self.foo + x example_inputs = (torch.rand(4, 4, device=self.device),) - torch._export.aot_compile(Model(self.device), example_inputs) - self.check_model(Model(self.device), example_inputs) + with config.patch({"aot_inductor.allow_buffer_mutation": True}): + torch._export.aot_compile(Model(self.device), example_inputs) + self.check_model(Model(self.device), example_inputs) def test_non_tensor_input(self): def fn(a, b, alpha=1.0): @@ -1241,8 +1246,9 @@ def forward(self, x): self.foo[5] = self.bar[0] return x + self.bar, x * self.foo - example_inputs = (torch.randn(10, device=self.device),) - self.check_model(Model(self.device), example_inputs) + with config.patch({"aot_inductor.allow_buffer_mutation": True}): + example_inputs = (torch.randn(10, device=self.device),) + self.check_model(Model(self.device), example_inputs) def test_buffer_mutation_3(self): class KVCache(torch.nn.Module): @@ -1282,7 +1288,8 @@ def forward(self, inp_pos, k, v): torch.randn(1, 6, 1, 48, device=self.device), torch.randn(1, 6, 1, 48, device=self.device), ) - self.check_model(Model(self.device), example_inputs) + with config.patch({"aot_inductor.allow_buffer_mutation": True}): + self.check_model(Model(self.device), example_inputs) @requires_multigpu() def test_replicate_on_devices(self): @@ -2975,7 +2982,7 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_addmm_multiple_dynamic": fail_cuda(is_skip=True), "test_bmm_multiple_dynamic": fail_cuda(is_skip=True), "test_convolution": fail_cuda(is_skip=True), - "test_large": fail_cuda(is_skip=True), + "test_large_weight": fail_cuda(is_skip=True), "test_large_mmaped_weights": fail_cuda(is_skip=True), "test_missing_cubin": fail_cuda(is_skip=True), "test_multi_device": fail_cuda(is_skip=True), @@ -3020,7 +3027,7 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_constant_folding": fail_minimal_arrayref_interface(is_skip=True), "test_convolution": fail_minimal_arrayref_interface(is_skip=True), "test_empty_graph": fail_minimal_arrayref_interface(is_skip=True), - "test_large": 
fail_minimal_arrayref_interface(is_skip=True), + "test_large_weight": fail_minimal_arrayref_interface(is_skip=True), "test_large_mmaped_weights": fail_minimal_arrayref_interface(is_skip=True), "test_misc_1": fail_minimal_arrayref_interface(is_skip=True), "test_missing_output": fail_minimal_arrayref_interface(is_skip=True), diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 70b4671431115..528cebdc2ee57 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1924,12 +1924,19 @@ def _compile_consts_linux(consts: bytes) -> str: run_command_and_check(cmd) log.debug("aot constant binary command: %s", cmd) - # .data section is between .text and .bss. When the size of .data is large, - # during the linking, the relocation of .text against .bss may overflow. - # Rename it to .ldata so that it won't be in between the .text and .bss section + if config.aot_inductor.allow_buffer_mutation: + # .data section is between .text and .bss. When the size of .data is large, + # during the linking, the relocation of .text against .bss may overflow. + # Rename it to .ldata so that it won't be in between the .text and .bss section + rename_data = " .data=.ldata" + else: + # if no buffer mutation is needed, we could instead set the data region + # as read-only (i.e. .lrodata) which could accomodate larger size of data + # to be linked. + rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" cmd = ( f"{objcopy_command} --rename-section" - " .data=.ldata" + f"{rename_data}" " --set-section-alignment .data=64" # following the gAlignment of CPU in c10/core/alignment.h f" {consts_o} {consts_o}" ) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index db8a6d9ae3b63..efe1f24b801de 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -744,6 +744,9 @@ class aot_inductor: # rather than embedded into the data section. Needed to support 1B+ parameter models force_mmap_weights: bool = False + # flag to allow buffer mutation. This would remove the read-only property from buffers. + allow_buffer_mutation: bool = False + class cuda: # CUDA arch to use for CUDA template kernel compilation. From 2813f0672aebd5207adb4fd8ea48f8953ba26740 Mon Sep 17 00:00:00 2001 From: Yueming Hao Date: Mon, 20 May 2024 19:10:41 +0000 Subject: [PATCH 14/35] fix huggingface models input issue in torchbench (#126579) Fixes https://github.com/pytorch/benchmark/issues/2263. According to https://github.com/pytorch/pytorch/blob/main/benchmarks/dynamo/common.py#L509, example_inputs are formatted as dictionaries for HuggingFace models. However, this forward_pass function passes all inputs to mod with *, which may only pass the input_ids key in HuggingFace model's example inputs. To reproduce, run the following command. 
```bash python pytorch/benchmarks/dynamo/torchbench.py --performance --inference -dcuda --only=hf_Bert --output=torchbench_inference.csv ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126579 Approved by: https://github.com/xuzhao9 --- benchmarks/dynamo/torchbench.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index 3086bddc4bb5b..2a9437e08b698 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -423,13 +423,19 @@ def compute_loss(self, pred): def forward_pass(self, mod, inputs, collect_outputs=True): with self.autocast(**self.autocast_arg): - return mod(*inputs) + if isinstance(inputs, dict): + return mod(**inputs) + else: + return mod(*inputs) def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): cloned_inputs = clone_inputs(inputs) self.optimizer_zero_grad(mod) with self.autocast(**self.autocast_arg): - pred = mod(*cloned_inputs) + if isinstance(clone_inputs, dict): + pred = mod(**cloned_inputs) + else: + pred = mod(*cloned_inputs) loss = self.compute_loss(pred) self.grad_scaler.scale(loss).backward() self.optimizer_step() From 8bca0847c2c2a80b5dfab6dd554c7dd68199a3b3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 20 May 2024 20:15:08 +0000 Subject: [PATCH 15/35] Revert "[TD] Upload names of failures to s3 for pytest cache (#126315)" This reverts commit 655038687afd19a4a4c9371b77ff046fd6c84be1. Reverted https://github.com/pytorch/pytorch/pull/126315 on behalf of https://github.com/clee2000 due to broke inductor ([comment](https://github.com/pytorch/pytorch/pull/126315#issuecomment-2121133045)) --- .github/scripts/pytest_caching_utils.py | 29 ------------------- test/run_test.py | 6 ---- tools/stats/import_test_stats.py | 15 ---------- .../testing/do_target_determination_for_s3.py | 2 -- .../heuristics/previously_failed_in_pr.py | 29 +------------------ 5 files changed, 1 insertion(+), 80 deletions(-) diff --git a/.github/scripts/pytest_caching_utils.py b/.github/scripts/pytest_caching_utils.py index e1a581c2d3095..b2a71fb4b8e13 100644 --- a/.github/scripts/pytest_caching_utils.py +++ b/.github/scripts/pytest_caching_utils.py @@ -18,7 +18,6 @@ PYTEST_CACHE_DIR_NAME = ".pytest_cache" BUCKET = "gha-artifacts" LASTFAILED_FILE_PATH = Path("v/cache/lastfailed") -TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL = "previous_failures_additional.json" # Temp folders ZIP_UPLOAD = "zip-upload" @@ -192,10 +191,6 @@ def _merge_pytest_caches( pytest_cache_dir_to_merge_from, pytest_cache_dir_to_merge_into ) - _merge_additional_failures_files( - pytest_cache_dir_to_merge_from, pytest_cache_dir_to_merge_into - ) - def _merge_lastfailed_files(source_pytest_cache: Path, dest_pytest_cache: Path) -> None: # Simple cases where one of the files doesn't exist @@ -237,27 +232,3 @@ def _merged_lastfailed_content( del to_lastfailed[""] return to_lastfailed - - -def _merge_additional_failures_files( - source_pytest_cache: Path, dest_pytest_cache: Path -) -> None: - # Simple cases where one of the files doesn't exist - source_lastfailed_file = ( - source_pytest_cache / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL - ) - dest_lastfailed_file = dest_pytest_cache / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL - - if not source_lastfailed_file.exists(): - return - if not dest_lastfailed_file.exists(): - copy_file(source_lastfailed_file, dest_lastfailed_file) - return - - # Both files exist, so we need to merge them - from_lastfailed = 
load_json_file(source_lastfailed_file) - to_lastfailed = load_json_file(dest_lastfailed_file) - merged_content = list(set(from_lastfailed + to_lastfailed)) - - # Save the results - write_json_file(dest_lastfailed_file, merged_content) diff --git a/test/run_test.py b/test/run_test.py index d43a396f1441e..71ab08199f7a7 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -59,9 +59,6 @@ ) from tools.testing.do_target_determination_for_s3 import import_results from tools.testing.target_determination.gen_artifact import gen_ci_artifact -from tools.testing.target_determination.heuristics.previously_failed_in_pr import ( - gen_additional_test_failures_file, -) from tools.testing.target_determination.heuristics.utils import get_pr_number from tools.testing.test_run import TestRun @@ -1798,9 +1795,6 @@ def __str__(self): **test_stats, }, ) - gen_additional_test_failures_file( - [test.test_file for test, _ in all_failures] - ) if len(all_failures): for _, err in all_failures: diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 513edb12fcfe6..a256c3ea04c34 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -28,7 +28,6 @@ def get_disabled_issues() -> List[str]: TD_HEURISTIC_PROFILING_FILE = "td_heuristic_profiling.json" TD_HEURISTIC_HISTORICAL_EDITED_FILES = "td_heuristic_historical_edited_files.json" TD_HEURISTIC_PREVIOUSLY_FAILED = "previous_failures.json" -TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL = "previous_failures_additional.json" FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds @@ -166,20 +165,6 @@ def copy_pytest_cache() -> None: ) -def copy_additional_previous_failures() -> None: - original_path = ( - REPO_ROOT / ".pytest_cache" / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL - ) - if not original_path.exists(): - return - shutil.copyfile( - original_path, - REPO_ROOT - / ADDITIONAL_CI_FILES_FOLDER - / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL, - ) - - def get_from_test_infra_generated_stats( from_file: str, to_file: str, failure_explanation: str ) -> Dict[str, Any]: diff --git a/tools/testing/do_target_determination_for_s3.py b/tools/testing/do_target_determination_for_s3.py index 32ea85b980214..4b004801fc747 100644 --- a/tools/testing/do_target_determination_for_s3.py +++ b/tools/testing/do_target_determination_for_s3.py @@ -8,7 +8,6 @@ sys.path.insert(0, str(REPO_ROOT)) from tools.stats.import_test_stats import ( - copy_additional_previous_failures, copy_pytest_cache, get_td_heuristic_historial_edited_files_json, get_td_heuristic_profiling_json, @@ -52,7 +51,6 @@ def main() -> None: get_td_heuristic_historial_edited_files_json() get_td_heuristic_profiling_json() copy_pytest_cache() - copy_additional_previous_failures() aggregated_heuristics = get_test_prioritizations(selected_tests) diff --git a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py index 81e808d063c3a..8d15d7537e915 100644 --- a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py +++ b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py @@ -6,7 +6,6 @@ from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, TD_HEURISTIC_PREVIOUSLY_FAILED, - TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL, ) from tools.testing.target_determination.heuristics.interface import ( @@ -26,7 +25,7 @@ def __init__(self, **kwargs: Dict[str, Any]): super().__init__(**kwargs) def get_prediction_confidence(self, tests: 
List[str]) -> TestPrioritizations: - critical_tests = get_previous_failures() | read_additional_test_failures_file() + critical_tests = get_previous_failures() return TestPrioritizations( tests, {TestRun(test): 1 for test in critical_tests if test in tests} ) @@ -55,29 +54,3 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st prioritized_tests.add(test_file) return prioritized_tests - - -def gen_additional_test_failures_file(tests: List[str]) -> None: - # Segfaults usually result in no xml and some tests don't run through pytest - # (ex doctests). In these cases, there will be no entry in the pytest - # cache, so we should generate a separate file for them and upload it to s3 - # along with the pytest cache - with open( - REPO_ROOT / ".pytest_cache" / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL, "w" - ) as f: - json.dump(tests, f, indent=2) - - -def read_additional_test_failures_file() -> Set[str]: - path = ( - REPO_ROOT - / ADDITIONAL_CI_FILES_FOLDER - / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL - ) - if not os.path.exists(path): - print(f"could not find path {path}") - return set() - with open(path) as f: - s = set(json.load(f)) - print(f"additional failures: {s}") - return s From d28868c7e8bcd41c9219f099aa5f7a5332c912fd Mon Sep 17 00:00:00 2001 From: jhavukainen <104022140+jhavukainen@users.noreply.github.com> Date: Mon, 20 May 2024 20:23:53 +0000 Subject: [PATCH 16/35] Change skipIfs to xfails in test_mps.py for test_isin (#125412) Follow-up to #124896 to move the added test to use expectedFailure instead of skip. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125412 Approved by: https://github.com/kulinseth --- .../native/mps/operations/TensorCompare.mm | 6 +++-- test/test_mps.py | 24 +++++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index f378af1326a73..4da5c302214d1 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -276,8 +276,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, bool invert, const Tensor& out, string op_name) { - TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), - "isin_Tensor_Tensor_out supported on MPS from MacOs_14_0 onwards"); if (elements.numel() == 0) { return; } @@ -295,6 +293,10 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, TORCH_CHECK(elements.is_mps() && test_elements.is_mps()); TORCH_CHECK(elements.dtype() == test_elements.dtype()); + TORCH_CHECK( + !(!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) && !supportedFloatingType(elements.scalar_type())), + "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. 
Received dtype: ", + elements.scalar_type()); @autoreleasepool { string key = diff --git a/test/test_mps.py b/test/test_mps.py index 24c4e2d45e48e..cbf8874e1c220 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -61,10 +61,17 @@ ) ) +def xfailIf(condition): + def wrapper(func): + if condition: + return unittest.expectedFailure(func) + else: + return func + return wrapper + def xfailIfMacOS14_4Plus(func): return unittest.expectedFailure(func) if product_version > 14.3 else func # noqa: F821 - def mps_ops_grad_modifier(ops): XFAILLIST_GRAD = { @@ -901,9 +908,9 @@ def mps_ops_modifier(ops): 'fft.rfft2': None, 'fft.rfftn': None, 'stft': None, - # Error in TestConsistencyCPU.test_output_match_isin_cpu_int32, + # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers, # not reproducible in later OS. Added assert to op if used in < 14.0 - 'isin': None, + 'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8], }) UNDEFINED_XFAILLIST = { @@ -8218,7 +8225,6 @@ def helper(dtype): [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]] - @unittest.skipIf(product_version < 14.0, "Skipped on MacOS < 14.0") def test_isin(self): def helper(dtype): shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]), @@ -8237,15 +8243,19 @@ def helper(dtype): B_mps = B.clone().detach().to('mps') cpu_ref = torch.isin(A, B, invert=inverted) - if dtype is torch.float16: + if dtype in [torch.float16, torch.bfloat16]: cpu_ref.type(dtype) mps_out = torch.isin(A_mps, B_mps, invert=inverted) self.assertEqual(mps_out, cpu_ref) - [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8]] + dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8] + if product_version < 14.0: + # Int types expected to fail on MacOS < 14.0 + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + [helper(dtype) for dtype in dtypes] - @unittest.skipIf(product_version < 14.0, "Skipped on MacOS < 14.0") def test_isin_asserts(self): A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32) B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16) From 3d4f1c3083e61349840f42b13656d13455aab25d Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 20 May 2024 20:50:11 +0000 Subject: [PATCH 17/35] [export] Make error name private (#126715) Fixes CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/126715 Approved by: https://github.com/clee2000 --- torch/fx/experimental/proxy_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 511762c1612cc..cdd9995bbcd36 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -922,7 +922,7 @@ def disable_autocast_cache(): torch.set_autocast_cache_enabled(old_value) -class ModuleNotInstalledAsSubmoduleError(NameError): +class _ModuleNotInstalledAsSubmoduleError(NameError): pass @@ -1006,7 +1006,7 @@ def path_of_module(self, mod: torch.nn.Module) -> str: try: return Tracer.path_of_module(self, mod) except NameError as e: - raise ModuleNotInstalledAsSubmoduleError from e + raise _ModuleNotInstalledAsSubmoduleError from e def getattr(self, attr, attr_val, parameter_proxy_cache): if not isinstance(attr_val, torch.nn.Module) or isinstance(attr_val, torch.fx.GraphModule): @@ -1081,7 +1081,7 @@ def call_module(self, m, forward, args, kwargs): 
try: return Tracer.call_module(self, m, forward, args, kwargs) - except ModuleNotInstalledAsSubmoduleError as e: + except _ModuleNotInstalledAsSubmoduleError as e: warnings.warn( f"Unable to find the path of the module {m}. " "This might be because the module was not properly registered " From acfe237a71af609e837a34bb38048aa8acb8eb4d Mon Sep 17 00:00:00 2001 From: Yueming Hao Date: Mon, 20 May 2024 20:57:50 +0000 Subject: [PATCH 18/35] Fix C++ compilation error for tensor array in abi_compatible mode (#126412) Fixes #122048 There is a compilation error https://github.com/pytorch/pytorch/issues/122048 when the element type in an array is tensor. It is because `val_to_arg_str does` not take arg type as input, and always generate an int array. This PR change the underlying `codegen_int_array_var` to `codegen_var_array` by adding type checks and corresponding code generations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126412 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 10 ++-- torch/_inductor/codegen/cpp_wrapper_cpu.py | 67 ++++++++++++++++++++-- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/ir.py | 6 +- 4 files changed, 70 insertions(+), 17 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index f1b153592db4b..5545cab3c0788 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -2879,8 +2879,10 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_duplicate_constant_folding": fail_with_and_without_stack_allocation( is_skip=True ), - "test_dup_unbacked_sym_decl": fail_with_and_without_stack_allocation(), - "test_dup_unbacked_sym_decl_with_refinement": fail_with_and_without_stack_allocation(), + "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True), + "test_dup_unbacked_sym_decl_with_refinement": fail_minimal_arrayref_interface( + is_skip=True + ), "test_dynamic_cat": fail_minimal_arrayref_interface(), # https://github.com/pytorch/pytorch/issues/122978 "test_dynamic_scalar": fail_stack_allocation(is_skip=True), @@ -2957,8 +2959,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): CUDA_TEST_FAILURES = { # test_failures, xfail by default, set is_skip=True to skip - "test_dup_unbacked_sym_decl": fail_abi_compatible_cuda(), - "test_dup_unbacked_sym_decl_with_refinement": fail_abi_compatible_cuda(), "test_large_grid": fail_cuda(), "test_normal_functional": fail_abi_compatible_cuda(), # There is a double-free issue which will be fixed in another PR @@ -2977,8 +2977,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): if TEST_WITH_ROCM: CUDA_TEST_FAILURES.update( { - "test_dup_unbacked_sym_decl": fail_cuda(is_skip=True), - "test_dup_unbacked_sym_decl_with_refinement": fail_cuda(is_skip=True), "test_addmm_multiple_dynamic": fail_cuda(is_skip=True), "test_bmm_multiple_dynamic": fail_cuda(is_skip=True), "test_convolution": fail_cuda(is_skip=True), diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9595f1da6f957..38df2331315ed 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -46,6 +46,10 @@ def __init__(self): self.supports_intermediate_hooks = False self.outputs_need_copy = set() self.kernel_callsite_id = count() + self.var_array_id = ( + count() + ) # for different types of local array variable declarations + self.declared_var_array_vars = set() self.int_array_id = count() # for int array local variable 
declarations self.declared_int_array_vars = set() self.tmp_tensor_id = count() # for tmp tensor local variable declarations @@ -1511,6 +1515,43 @@ def codegen_int_array_var( writer.writeline(f"const {ctype} {var}[] = {int_array};") return var + @functools.lru_cache(None) + def codegen_var_array( + self, + var_array: str, + writer=None, + known_statically=False, + graph=None, # for per-graph caching + type_hint=None, # ['int64_t', 'tensor', 'bool'] + ): + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass + if writer is None: + # The first pass codegen uses `self` as the writer + writer = self + if not type_hint or type_hint in ["bool", "int64_t"]: + return self.codegen_int_array_var( + var_array, + writer, + known_statically, + graph, + is_bool=type_hint == "bool", + ) + + var = f"var_array_{next(self.var_array_id)}" + assert type_hint == "tensor" + ctype = "AtenTensorHandle*" + if var not in self.declared_var_array_vars: + self.declared_var_array_vars.add(var) + if known_statically: + writer.writeline(f"static constexpr {ctype} {var}[] = {var_array};") + else: + writer.writeline(f"const {ctype} {var}[] = {var_array};") + return var + def make_buffer_allocation(self, buffer): return self.make_allocation( buffer.get_name(), @@ -2243,7 +2284,7 @@ def generate_reset_kernel_saved_flags(self): def generate_save_uncompiled_kernels(self): pass - def val_to_cpp_arg_str(self, type_, val) -> str: + def val_to_cpp_arg_str(self, val, type_) -> str: if config.abi_compatible and isinstance(type_, torch.OptionalType): if val is None: return "0" # nullptr is not available in C @@ -2280,9 +2321,9 @@ def val_to_cpp_arg_str(self, type_, val) -> str: self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") return f"&{var_name}" - return self.val_to_arg_str(val) + return self.val_to_arg_str(val, type_) - def val_to_arg_str(self, val) -> str: + def val_to_arg_str(self, val, type_=None) -> str: if val is None: # When None is passed as an argument, it represents an optional that does not contain a value. 
if config.abi_compatible: @@ -2317,14 +2358,28 @@ def val_to_arg_str(self, val) -> str: if config.abi_compatible: assert len(val) > 0, "Empty array is not supported in C" static = self.is_statically_known_list_of_ints(val) + type_hint = "bool" if isinstance(val[0], bool) else "int64_t" + if ( + type_ is not None + and isinstance(type_, torch._C.ListType) + and isinstance(type_.getElementType(), torch._C.OptionalType) + and isinstance( + type_.getElementType().getElementType(), torch._C.TensorType + ) + ): + type_hint = "tensor" + tmp_arg_list = "" + for x in val: + tmp_arg_list += f"&{x}_handle, " + result = f"{{{tmp_arg_list}}}" # Need to pass the array length because we can't use std::vector - int_var_array = self.codegen_int_array_var( + var_array = self.codegen_var_array( result, known_statically=static, graph=self.get_codegened_graph(), - is_bool=isinstance(val[0], bool), + type_hint=type_hint, ) - return f"{int_var_array}, {len(val)}" + return f"{var_array}, {len(val)}" else: return result else: diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 5d9a11de149e0..ff2fe1ec87cc0 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1412,10 +1412,10 @@ def writelines(self, lines): def enter_context(self, ctx): self.lines.append(LineContext(ctx)) - def val_to_cpp_arg_str(self, type_, val) -> str: + def val_to_cpp_arg_str(self, val, type_) -> str: raise NotImplementedError - def val_to_arg_str(self, s): + def val_to_arg_str(self, s, type_=None): from torch.utils._triton import dtype_to_string, has_triton_package if has_triton_package(): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 689877ba69281..b6345eee6d87a 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4405,7 +4405,7 @@ def codegen_args(self): type_ = self.arg_properties[i].get("type") args.append( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] - type_, x + x, type_ ) ) else: @@ -4440,7 +4440,7 @@ def codegen_kwargs(self, skip_out=False): ) kwargs.append( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] - type_, v + v, type_ ) ) else: @@ -5408,7 +5408,7 @@ def __repr__(self): if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): args = self.fill_non_provided_args(args, kwargs) args = [ - V.graph.wrapper_code.val_to_cpp_arg_str(param.real_type, x) + V.graph.wrapper_code.val_to_cpp_arg_str(x, param.real_type) for param, x in zip(self.op_overload._schema.arguments, args) ] else: From 74b053d7c4184b2a8e8732777e917dbb4d6788a6 Mon Sep 17 00:00:00 2001 From: Wei Han Date: Mon, 20 May 2024 22:17:52 +0000 Subject: [PATCH 19/35] Pass model path to observer (#126503) Summary: Passing model path to observer so that they can get additional info if needed. 
Test Plan: contbuild & OSS CI Differential Revision: D57475129 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126503 Approved by: https://github.com/kirklandsign --- torch/csrc/jit/mobile/import.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 96ff6c88779d9..da7b87bae6110 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -651,6 +651,10 @@ mobile::Module _load_for_mobile( std::optional device, ExtraFilesMap& extra_files, uint64_t module_load_options) { + auto observer = torch::observerConfig().getModuleObserver(); + if (observer) { + extra_files.insert(std::make_pair("model_path", filename)); + } auto format = getFileFormat(filename); if (format == FileFormat::FlatbufferFileFormat) { From 6f1935b0b54e551478f691f3192afb065e399062 Mon Sep 17 00:00:00 2001 From: Alexander Kurakin Date: Mon, 20 May 2024 22:20:31 +0000 Subject: [PATCH 20/35] doc: `torch.utils.data.Sampler`: `__len__` is optional (#125938) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/125938 Approved by: https://github.com/andrewkho, https://github.com/xmfan --- torch/utils/data/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 0f100f1858419..4c4c967ef9a9a 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -19,8 +19,8 @@ class Sampler(Generic[T_co]): r"""Base class for all Samplers. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a - way to iterate over indices or lists of indices (batches) of dataset elements, and a :meth:`__len__` method - that returns the length of the returned iterators. + way to iterate over indices or lists of indices (batches) of dataset elements, + and may provide a :meth:`__len__` method that returns the length of the returned iterators. Args: data_source (Dataset): This argument is not used and will be removed in 2.2.0. From 14dc8d4f637283fbed35f1a6a65a9fac21ff7876 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 20 May 2024 10:37:31 -0700 Subject: [PATCH 21/35] Protect codecache against cache failures (#126696) When there's a manifold, memcache or filesystem related issues or network outages, we should not completely fail to compile but instead fallback to cold start. 
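For illustration, a minimal sketch of the fail-open pattern this change adds around the FX graph cache (the helper below is a simplified stand-in, not the actual codecache API): any exception raised while reading from or writing to the cache is logged with `exc_info` and treated as a cache miss, so compilation proceeds as a cold start instead of failing.

```python
import logging

log = logging.getLogger(__name__)


def load_or_compile(remote_cache, key, compile_fn):
    # Read path: a cache failure is only a warning, never a hard error.
    data = None
    try:
        data = remote_cache.get(key)
    except Exception:
        log.warning("fx graph cache unable to load compiled graph", exc_info=True)
    if data is not None:
        return data

    # Cache miss (or cache outage): fall back to a cold compile.
    compiled = compile_fn()

    # Write path: likewise best-effort.
    try:
        remote_cache.put(key, compiled)
    except Exception:
        log.warning("fx graph unable to write to cache", exc_info=True)
    return compiled
```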
Differential Revision: [D57573835](https://our.internmc.facebook.com/intern/diff/D57573835/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126696 Approved by: https://github.com/aorenste --- torch/_inductor/codecache.py | 73 +++++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 27 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 528cebdc2ee57..6efa7e0db572c 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -773,11 +773,23 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: subdir = FxGraphCache._get_tmp_dir_for_key(key) if os.path.exists(subdir): for path in sorted(os.listdir(subdir)): - with open(os.path.join(subdir, path), "rb") as f: - yield pickle.load(f) + try: + with open(os.path.join(subdir, path), "rb") as f: + yield pickle.load(f) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", + exc_info=True, + ) + if remote_cache: - if (data := remote_cache.get(key)) is not None: - yield pickle.loads(data) + try: + if (data := remote_cache.get(key)) is not None: + yield pickle.loads(data) + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", exc_info=True + ) # Iterate over any entries in the subdir for this key and evaluate # their guards to determine whether there's a hit. @@ -890,32 +902,39 @@ def _save_graph( try: content = pickle.dumps(disk_compiled_graph) - except Exception as e: - log.debug("fx graph cache unable to serialize compiled graph: %s", e) + except Exception: + log.warning( + "fx graph cache unable to serialize compiled graph", exc_info=True + ) counters["inductor"]["fxgraph_cache_pickle_error"] += 1 return - if local: - subdir = FxGraphCache._get_tmp_dir_for_key(key) - if not os.path.exists(subdir): - os.makedirs(subdir, exist_ok=True) - - # Use a hash of the serialized CompiledFxGraph to get a unique file - # name. The specific name doesn't matter since a lookup involves - # iterating over all entries in the parent subdir. - path = os.path.join(subdir, sha256_hash(content)) - write_atomic(path, content, make_dirs=True) - - if remote_cache: - cache_data = ( - { - "data": content, - "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS - } - if config.is_fbcode() - else content - ) - remote_cache.put(key, cache_data) + try: + if local: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. 
+ path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content, make_dirs=True) + + if remote_cache: + cache_data = ( + { + "data": content, + "time_taken_ms": time_taken_ns + // 1000000, # Convert from NS to MS + } + if config.is_fbcode() + else content + ) + remote_cache.put(key, cache_data) + except Exception: + log.warning("fx graph unable to write to cache", exc_info=True) + counters["inductor"]["fxgraph_cache_write_error"] += 1 @staticmethod def _check_can_cache(gm: torch.fx.GraphModule): From 831efeeadf5fa8d9e7f973057e634a57e3bcf04b Mon Sep 17 00:00:00 2001 From: chilli Date: Sun, 19 May 2024 20:50:19 -0700 Subject: [PATCH 22/35] Fix flexattention not realizing inputs before lowering (also refactored runtime estimation) (#126615) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126615 Approved by: https://github.com/yanboliang, https://github.com/drisspg, https://github.com/xmfan --- test/inductor/test_control_flow.py | 1 + test/inductor/test_flex_attention.py | 55 ++++++++++++++++++++-- test/inductor/test_group_batch_fusion.py | 1 + test/inductor/test_perf.py | 27 ++++------- test/inductor/test_snode_runtime.py | 14 ++++-- torch/_inductor/compile_fx.py | 31 ++----------- torch/_inductor/graph.py | 6 +-- torch/_inductor/kernel/flex_attention.py | 4 ++ torch/_inductor/metrics.py | 45 +++++++----------- torch/_inductor/scheduler.py | 56 +++++++++++++++++------ torch/nn/attention/_flex_attention.py | 24 +++++----- torch/testing/_internal/inductor_utils.py | 6 --- 12 files changed, 149 insertions(+), 121 deletions(-) diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 47a5980b6d79c..833693dab934a 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -2,6 +2,7 @@ import itertools import torch +import torch._dynamo.testing from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import ( diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index f3a9026a3c805..02277ade44370 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +# flake8: noqa: B950 import functools import unittest @@ -12,6 +13,7 @@ from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop +from torch._inductor import metrics from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import run_and_get_code from torch.nn.attention._flex_attention import ( @@ -499,9 +501,8 @@ def score_mod_func(score, b, h, q, kv): ) query, key, value = make_tensor(), make_tensor(), make_tensor() # floor_div is not decomposed in decompostion_table is empty - gm = make_fx(_flex_attention, decomposition_table={})( - query, key, value, score_mod_func - ) + flex_attention = functools.partial(_flex_attention, score_mod=score_mod_func) + gm = make_fx(flex_attention, decomposition_table={})(query, key, value) self.assertExpectedInline( gm.sdpa_score0.code.strip(), """\ @@ -513,8 +514,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): ) # floor_div is decomposed for core_aten_decompositions - gm = make_fx(_flex_attention, decomposition_table=core_aten_decompositions())( - query, key, value, score_mod_func + gm = make_fx(flex_attention, decomposition_table=core_aten_decompositions())( + query, key, value ) self.assertExpectedInline( 
gm.sdpa_score0.code.strip(), @@ -645,6 +646,50 @@ def f(q, k1, k2, k3, v1, v2, v3): out2 = torch.compile(f)(query, *keys, *values) self.assertTrue((out - out2).abs().mean() < 1e-2) + @supported_platform + def test_inputs_are_realized(self): + def f(q, k, v): + x = torch.randn(1024, device="cuda") + x = x * 2 + + def func(qk, b, h, q, kv): + return qk + x[q] + + return _flex_attention(q.sin(), k, v, score_mod=func).cos() + + q, k, v = ( + torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True) + for _ in range(3) + ) + ref = f(q, k, v) + out = torch.compile(f)(q, k, v) + self.assertTrue((ref - out).abs().mean() < 1e-2) + gradOut = torch.randn_like(q) + + ref_grads = torch.autograd.grad(ref, (q, k, v), gradOut) + out_grads = torch.autograd.grad(out, (q, k, v), gradOut) + for ref, out in zip(ref_grads, out_grads): + self.assertTrue((ref - out).abs().mean() < 1e-2) + + @supported_platform + def test_epilogue_fused(self): + @torch.compile + def f(q, k, v): + out = _flex_attention(q, k, v) + return out.cos() + + q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3)) + metrics.reset() + f(q, k, v) + accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize + num_accesses = 4 # q, k, v reads, one output. + # TODO: Get rid of this fudge factor + # We need this fudge factor for now, since + # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow) + # 2. We also write the extraneous logsumexp + num_accesses += 2 + self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses) + @supported_platform @skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571 @common_utils.parametrize("dtype", test_dtypes) diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 6dd2ff51219d7..b203a0f63e8b1 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -5,6 +5,7 @@ import torch import torch._inductor +import torch._inductor.fx_passes.group_batch_fusion from torch._dynamo.utils import counters, optimus_scuba_log from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CUDA diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 09e913350e143..5e1af26f4bfac 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -8,9 +8,9 @@ import torch._inductor.config as config import torch.autograd from torch._inductor import metrics -from torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch._inductor.compile_fx import compile_fx, compile_fx_inner from torch._inductor.test_case import TestCase as InductorTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm +from torch.testing._internal.common_utils import skipIfRocm ######################## # Explanation of Tests # @@ -36,21 +36,12 @@ aten = torch.ops.aten -def count_bytes_inductor(gm, example_inputs): - return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) +def compile_but_use_eager(gm, example_inputs): + def inner_compile(gm, *args, **kwargs): + compile_fx_inner(gm, *args, **kwargs) + return gm - -# We don't support torch.compile() on Windows -if not IS_WINDOWS: - - @torch._dynamo.optimize(count_bytes_inductor) - def f(x): - return torch.cat([x, x.cos()]) - -else: - - def f(x): - return torch.cat([x, x.cos()]) + return compile_fx(gm, example_inputs, inner_compile=inner_compile) def count_numel(f, *args): @@ 
-58,7 +49,7 @@ def count_numel(f, *args): Assumes all inputs are fp32 """ metrics.reset() - torch._dynamo.optimize(count_bytes_inductor)(f)(*args) + torch.compile(f, backend=compile_but_use_eager)(*args) print(metrics.nodes_num_elem) return str(metrics.num_bytes_accessed // 4) @@ -69,7 +60,7 @@ def count_numel_train(f, *args): """ metrics.reset() - f = torch._dynamo.optimize(count_bytes_inductor)(f) + f = torch.compile(f, backend=compile_but_use_eager) out = f(*args) res = 0 for o in out: diff --git a/test/inductor/test_snode_runtime.py b/test/inductor/test_snode_runtime.py index b62c219f85e81..0d9ed849e0d50 100644 --- a/test/inductor/test_snode_runtime.py +++ b/test/inductor/test_snode_runtime.py @@ -8,7 +8,7 @@ from torch._inductor import metrics from torch._inductor.comm_analysis import estimate_nccl_collective_runtime -from torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch._inductor.compile_fx import compile_fx, compile_fx_inner from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_collective from torch.testing._internal.inductor_utils import HAS_CUDA @@ -18,8 +18,12 @@ _c10d = torch.ops._c10d_functional -def count_bytes_inductor(gm, example_inputs): - return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) +def compile_but_use_eager(gm, example_inputs): + def inner_compile(gm, *args, **kwargs): + compile_fx_inner(gm, *args, **kwargs) + return gm + + return compile_fx(gm, example_inputs, inner_compile=inner_compile) def calculate_runtime(f, *args) -> float: @@ -27,7 +31,7 @@ def calculate_runtime(f, *args) -> float: Assumes all inputs are fp32 """ metrics.reset() - torch._dynamo.optimize(count_bytes_inductor)(f)(*args) + torch.compile(f, backend=compile_but_use_eager)(*args) print(metrics.node_runtimes) ret = 0.0 @@ -187,7 +191,7 @@ def _verify_runtime_estimation(self, fn, inps): ) try: metrics.reset() - torch._dynamo.optimize(count_bytes_inductor)(fn)(*inps) + torch.compile(fn)(*inps) found_collective = False for snode, runtime in metrics.node_runtimes: if not is_collective(snode.node): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index bdbfef2eee28f..5ad2b40894189 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -338,31 +338,6 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): return contextlib.nullcontext() -@DebugContext.wrap -def count_bytes_inner( - gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - num_fixed: int = 0, - **kwargs, -): - shape_env = _shape_env_from_inputs(example_inputs) - fake_mode = fake_tensor_prop(gm, example_inputs) - - with V.set_fake_mode(fake_mode): - _recursive_post_grad_passes(gm, False) - - graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed) - with V.set_graph_handler(graph), V.set_real_inputs( - example_inputs - ), maybe_disable_comprehensive_padding(example_inputs): - graph.run(*example_inputs) - num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() - metrics.num_bytes_accessed += num_bytes - metrics.nodes_num_elem += nodes_num_elem - metrics.node_runtimes += node_runtimes - return make_boxed_func(gm.forward) - - def fake_tensor_prop( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], @@ -795,6 +770,7 @@ def fx_codegen_and_compile( const_code=const_code, const_module=const_graph, ) + metrics_helper = metrics.CachedMetricsHelper() with V.set_graph_handler(graph): graph.run(*example_inputs) output_strides: 
List[Optional[Tuple[int, ...]]] = [] @@ -814,8 +790,11 @@ def fx_codegen_and_compile( else: output_strides.append(None) - metrics_helper = metrics.CachedMetricsHelper() compiled_fn = graph.compile_to_fn() + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.node_runtimes += node_runtimes + metrics.nodes_num_elem += nodes_num_elem if ( cudagraphs diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index bfb7b8dea7ebb..0adf356f6262a 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1656,14 +1656,10 @@ def codegen_subgraph(self, parent_graph): self.scheduler.codegen() def count_bytes(self): - from .scheduler import Scheduler - - scheduler = Scheduler(self.buffers) - total_bytes = 0 node_counts = [] node_runtimes = [] - for node in scheduler.nodes: + for node in self.scheduler.nodes: num_bytes = node.get_read_write_buffers_sizes() total_bytes += num_bytes node_counts.append((node, num_bytes // 4)) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 32dff9d46668c..46e72d8221f72 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -367,6 +367,8 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) def flex_attention(*args, **kwargs): query, key, value, subgraph, *other_buffers = args + for buf in [query, key, value]: + buf.realize() placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ @@ -640,6 +642,8 @@ def flex_attention_backward(*args, **kwargs): joint_graph, *other_buffers, ) = args + for buf in [query, key, value, grad_out]: + buf.realize() device = query.get_device() dtype = query.get_dtype() diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 3a1b83045f4a0..76f15243c5ba1 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -1,6 +1,7 @@ from __future__ import annotations import csv +import dataclasses import inspect import os import re @@ -78,6 +79,11 @@ class CachedMetricsDeltas: generated_cpp_vec_kernel_count: int ir_nodes_pre_fusion: int cpp_to_dtype_count: int + num_bytes_accessed: int + + +def get_metric_fields(): + return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] class CachedMetricsHelper: @@ -88,40 +94,21 @@ class CachedMetricsHelper: """ def __init__(self): - global generated_kernel_count - global generated_cpp_vec_kernel_count - global ir_nodes_pre_fusion - global cpp_to_dtype_count - - self.generated_kernel_count = generated_kernel_count - self.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count - self.ir_nodes_pre_fusion = ir_nodes_pre_fusion - self.cpp_to_dtype_count = cpp_to_dtype_count + self.cached_metrics = {} + for metric in get_metric_fields(): + self.cached_metrics[metric] = globals()[metric] def get_deltas(self) -> CachedMetricsDeltas: - global generated_kernel_count - global generated_cpp_vec_kernel_count - global ir_nodes_pre_fusion - global cpp_to_dtype_count - - return CachedMetricsDeltas( - generated_kernel_count - self.generated_kernel_count, - generated_cpp_vec_kernel_count - self.generated_cpp_vec_kernel_count, - ir_nodes_pre_fusion - self.ir_nodes_pre_fusion, - cpp_to_dtype_count - self.cpp_to_dtype_count, - ) + delta_metrics = {} + for metric in get_metric_fields(): + delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] + + 
return CachedMetricsDeltas(**delta_metrics) @staticmethod def apply_deltas(delta: CachedMetricsDeltas): - global generated_kernel_count - global generated_cpp_vec_kernel_count - global ir_nodes_pre_fusion - global cpp_to_dtype_count - - generated_kernel_count += delta.generated_kernel_count - generated_cpp_vec_kernel_count += delta.generated_cpp_vec_kernel_count - ir_nodes_pre_fusion += delta.ir_nodes_pre_fusion - cpp_to_dtype_count += delta.cpp_to_dtype_count + for metric in get_metric_fields(): + globals()[metric] += getattr(delta, metric) REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index ec4763160a7b6..edaa944722700 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,6 +28,7 @@ import torch from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._sympy.symbol import free_symbol_is_type, SymT from torch.utils._triton import has_triton @@ -505,12 +506,16 @@ def get_read_write_buffers_sizes(self) -> int: if isinstance(self, ExternKernelSchedulerNode) and isinstance( self.node, MultiOutput ): + # todo: Calculate this - it's kinda annoying. return 0 + def try_size_hint(s): + return V.graph.sizevars.size_hint(s, fallback=0) + if isinstance(self, SchedulerNode): - node_numel = V.graph.sizevars.size_hint( + node_numel = try_size_hint( sympy_product(self.get_ranges()[0]) - * sympy_product(self.get_ranges()[1]) + * sympy_product(self.get_ranges()[1]), ) else: node_numel = int(1e9) @@ -545,16 +550,24 @@ def is_materialized(buf, snodes): continue def get_buf_elems(buf): - return V.graph.sizevars.size_hint(sympy_product(buf.get_size())) - - # Kind of a lazy way to get the MultiOutput nodes corresponding to - # a MultiOutputLayout - if isinstance(buf.layout, MultiOutputLayout): - users = self.scheduler.name_to_node[buf.get_name()].users - buf_elems = sum(get_buf_elems(user.node.node) for user in users) - else: - buf_elems = get_buf_elems(buf) + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_node[buf.get_name()].users + tot = 0 + for user in users: + if isinstance(user.node.node, MultiOutput): + tot += get_buf_elems(user.node.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... + # TODO: Figure out what's going on + return 0 + return tot + else: + return try_size_hint(sympy_product(buf.get_size())) + buf_elems = get_buf_elems(buf) node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size( buf.get_dtype() ) @@ -580,13 +593,20 @@ def get_estimated_runtime(self) -> float: layout = self.node.get_layout() dtype = self.node.get_dtype() - if not is_gpu(layout.device.type): + if layout.device is not None and not is_gpu(layout.device.type): # default to no reordering based on runtime return 0 # Collective kernels if is_collective(self.node): - return estimate_nccl_collective_runtime(self.node) + try: + return estimate_nccl_collective_runtime(self.node) + except ValueError as e: + # We don't know how to estimate runtime for this collective, + # falling back to 0 + log.info(e) + return 0 + elif is_wait(self.node): # ir.Wait is only used for collective ops. 
# The time needed for the collective op is already estimated and considered @@ -611,7 +631,14 @@ def get_estimated_runtime(self) -> float: from torch._subclasses.fake_tensor import FakeTensorMode from torch.utils.flop_counter import FlopCounterMode - assert self.node.fx_node is not None + if any( + len(free_unbacked_symbols(n.get_numel())) > 0 + for n in self.node.inputs + ): + # Tensor has unbacked symints, we don't know how to estimate + # runtime for that today + return 0 + with FakeTensorMode() as fake_mode, FlopCounterMode( display=False ) as flop_counter_mode, V.set_current_node( @@ -619,7 +646,6 @@ def get_estimated_runtime(self) -> float: ), V.set_fake_mode( fake_mode ): - assert V.current_node is not None from .ir import ir_node_to_tensor fake_inputs = [ diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index c56374fcbc40d..bd999ec39118d 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -28,11 +28,21 @@ def inner(score, b, h, m, n): ] +def _identity( + score: torch.Tensor, + batch: torch.Tensor, + head: torch.Tensor, + token_q: torch.Tensor, + token_kv: torch.Tensor, +) -> torch.Tensor: + return score + + def _flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - score_mod: _score_mod_signature, + score_mod: _score_mod_signature = _identity, ) -> torch.Tensor: r"""This function implements scaled dot product attention with an arbitrary attention score modification function. @@ -63,7 +73,7 @@ def score_mod( query (Tensor): Query tensor; shape :math:`(B, H, L, E)`. key (Tensor): Key tensor; shape :math:`(B, H, S, E)`. value (Tensor): Value tensor; shape :math:`(B, H, S, Ev)`. - score_mod (Callable): Function to modify attention scores + score_mod (Callable): Function to modify attention scores. By default no score_mod is applied. Returns: output (Tensor): Attention output; shape :math:`(B, H, L, Ev)`. 
@@ -114,16 +124,6 @@ def score_mod( """Some common used score_mod functions for flex_attention in PyTorch.""" -def _identity( - score: torch.Tensor, - batch: torch.Tensor, - head: torch.Tensor, - token_q: torch.Tensor, - token_kv: torch.Tensor, -) -> torch.Tensor: - return score - - def _causal( score: torch.Tensor, batch: torch.Tensor, diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 0f6209a01c3f1..e8db1e394b96f 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -12,8 +12,6 @@ LazyVal, IS_FBCODE, ) -from torch._dynamo.backends.registry import register_backend -from torch._inductor.compile_fx import compile_fx, count_bytes_inner from torch.testing._internal.common_utils import TestCase def test_cpu(): @@ -48,10 +46,6 @@ def test_cpu(): GPU_TYPE = "cuda" if len(tmp_gpus) == 0 else tmp_gpus.pop() del tmp_gpus -@register_backend -def count_bytes_inductor(gm, example_inputs): - return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) - def _check_has_dynamic_shape( self: TestCase, code, From da2292ce6b37028746bf5beeae04442eef1e803d Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 20 May 2024 11:49:46 -0700 Subject: [PATCH 23/35] Prevent partitioner from ever saving views (#126446) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126446 Approved by: https://github.com/anijain2305 ghstack dependencies: #126615 --- .../test_replicate_with_compiler.py | 20 +++++++++++++------ test/dynamo/test_repros.py | 3 +-- test/inductor/test_flex_attention.py | 8 +++++++- torch/_functorch/partitioners.py | 15 ++++++++++++-- 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 354f99dabd739..d4960681b27ed 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -279,12 +279,16 @@ def bwd(loss): self.assertEqual(counters["inductor"]["ddp_buckets"], 3) return code - def test_bucketing_coalesced_op(self): - torch._inductor.config._fuse_ddp_communication_passes = [ + @torch._inductor.config.patch( + _fuse_ddp_communication_passes=[ "fuse_ddp_with_coalesced_op", "schedule_comm_wait", ] - + ) + # todo: This pass mucks things up since Inductor thinks its inference + # and can apply this. Should turn off these passes in compiled autograd + @torch._inductor.config.patch(reorder_for_locality=False) + def test_bucketing_coalesced_op(self): # Gradient is None code = self._test_bucketing() self.assertEqual(counters["inductor"]["ddp_buckets"], 3) @@ -311,12 +315,16 @@ def test_bucketing_coalesced_op(self): fc.run(code) - def test_bucketing_concat_op(self): - torch._inductor.config._fuse_ddp_communication_passes = [ + @torch._inductor.config.patch( + _fuse_ddp_communication_passes=[ "fuse_ddp_with_concat_op", "schedule_comm_wait", ] - + ) + # todo: This pass mucks things up since Inductor thinks its inference + # and can apply this. 
Should turn off these passes in compiled autograd + @torch._inductor.config.patch(reorder_for_locality=False) + def test_bucketing_concat_op(self): # Gradient is None code = self._test_bucketing() self.assertEqual(counters["inductor"]["ddp_buckets"], 3) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 96bf924e09990..85b95370db240 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4709,8 +4709,7 @@ def forward(self, primals_1, primals_2): _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None getitem = _foreach_copy[0]; _foreach_copy = None mm = torch.ops.aten.mm.default(getitem, getitem) - t_1 = torch.ops.aten.t.default(getitem); getitem = None - return [mm, t_1]""", + return [mm, getitem]""", ) self.assertEqual(out_ref, out_test) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 02277ade44370..f9a14d446aa3a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -964,7 +964,13 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + alias: "f64[2, 2, 8, 4]" = torch.ops.aten.alias.default(getitem); getitem = None + alias_2: "f64[2, 2, 8, 4]" = torch.ops.aten.alias.default(alias); alias = None + alias_3: "f64[2, 2, 8, 4]" = torch.ops.aten.alias.default(alias_2); alias_2 = None + alias_1: "f32[2, 2, 8]" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + alias_4: "f32[2, 2, 8]" = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_5: "f32[2, 2, 8]" = torch.ops.aten.alias.default(alias_4); alias_4 = None fw_graph = self.fw_graph joint_graph = self.joint_graph flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 0956ee7e367c4..b1a9502bcf3dd 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -533,12 +533,13 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: # Populate depth for the nodes. Depth is the distance from the inputs. depths = {} - output_node = next(iter(gm.graph.find_nodes(op="output"))) for node in gm.graph.nodes: if node.op == "placeholder": depths[node] = 0 else: - depths[node] = max([depths[arg] for arg in node.all_input_nodes], default=0) + depths[node] = ( + max((depths[arg] for arg in node.all_input_nodes), default=0) + 1 + ) def insert_node_in_graph(node): if node in env: @@ -802,6 +803,8 @@ def should_ban_recomputation(node): return False if node.target == operator.getitem: return False + if op_types.is_view(node): + return False if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: return False # NB: "recompute" == 0 means that must save this node. 
@@ -854,6 +857,14 @@ def is_materialized(node): def get_node_weight(node) -> float: mem_sz = _size_of(node) + if op_types.is_view(node): + # We never choose to save views, since views are free to recompute. + # It makes it a bit simpler to analyze + # NB: If they're not free to recompute (e.g. nested tensors)... I + # think we should modify checks for view_ops to `is_view` and check + # that. Basically, with nested tensors, `aten.view` is not a "view + # op". + return math.inf if isinstance(node.meta["val"], py_sym_types): # We never want to save symfloats From cd3a71f754a2248bcfe500de7c9860bd7d2002bf Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 20 May 2024 11:49:46 -0700 Subject: [PATCH 24/35] Fix silu test for flexattention (#126641) Pull Request resolved: https://github.com/pytorch/pytorch/pull/126641 Approved by: https://github.com/ezyang, https://github.com/drisspg ghstack dependencies: #126615, #126446 --- test/inductor/test_flex_attention.py | 2 -- torch/_inductor/kernel/flex_attention.py | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index f9a14d446aa3a..245e8f16ab641 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2,7 +2,6 @@ # flake8: noqa: B950 import functools -import unittest from collections import namedtuple from typing import Callable, Optional @@ -529,7 +528,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) - @unittest.skip("Silu decomp failing for full in backwards") def test_silu_on_score(self, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 46e72d8221f72..ddddbed11c829 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -115,9 +115,10 @@ def build_subgraph_buffer( # already created TensorBoxes as args from torch.utils._pytree import tree_map - env[node] = lowerings[node.target]( - *tree_map(lambda x: env[x] if x in env else x, node.args) + args, kwargs = tree_map( + lambda x: env[x] if x in env else x, (node.args, node.kwargs) ) + env[node] = lowerings[node.target](*args, **kwargs) elif node.op == "output": # For the output node we need to create a ComputedBuffer # which represents the actual score modification From b85f9d7fa2c9d279bcded5e312df8c3d1b09451d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 20 May 2024 06:35:35 -0700 Subject: [PATCH 25/35] Add symbolic_shape_specialization structured trace (#126450) This is typically the information you want when diagnosing why something overspecialized in dynamic shapes. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/126450 Approved by: https://github.com/albanD --- torch/fx/experimental/symbolic_shapes.py | 27 ++++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index be1be24137f88..7398fd4bf56a0 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -4397,6 +4397,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No Use this instead of `self.replacements[a] = tgt`. 
""" + if tgt == self.replacements.get(a, None): + return + # Precondition: a == tgt assert isinstance(a, sympy.Symbol) @@ -4487,14 +4490,24 @@ def issubset(x, y): "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) return - if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)): - # specializing to a constant, which is likely unexpected + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + "user_stack": structured.from_traceback(user_tb) if user_tb else None, + } + ) - # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g., - # when adding a to self.replacements, and again when simplifying an expression containing a. - # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is, - # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant. - if a not in self.replacements or tgt != self.replacements[a]: + if config.print_specializations: self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) From 4644611b140e90ceb7d3ac8a6a89ce2b6df7233a Mon Sep 17 00:00:00 2001 From: dshi7 Date: Tue, 21 May 2024 00:44:55 +0000 Subject: [PATCH 26/35] [cprofile] log manifold link instead of raw data to trace_structured (#126451) Internal D57459752 returns manifold URL and this PR adds to tlparse payload Pull Request resolved: https://github.com/pytorch/pytorch/pull/126451 Approved by: https://github.com/jamesjwu --- torch/_dynamo/convert_frame.py | 22 +++++++--------------- torch/_utils_internal.py | 6 +++--- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index d5c24a67d9e25..6dcb84fab8fc1 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,4 +1,3 @@ -import base64 import collections import cProfile import dis @@ -350,20 +349,13 @@ def profile_wrapper(*args, **kwargs): ps.sort_stats(pstats.SortKey.TIME).print_stats(20) ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) - maybe_upload_prof_stats_to_manifold(str(profile_path)) # fb-only - - torch._logging.trace_structured( - "artifact", - lambda: { - "name": "dynamo_cprofile_prof", - "type": "prof", - "encoding": "base64", - }, - payload_fn=lambda: base64.encodebytes( - open(profile_path, "rb").read() - ).decode("ascii"), - ) - + if manifold_link := maybe_upload_prof_stats_to_manifold( + str(profile_path) + ): # fb-only + torch._logging.trace_structured( + "link", + lambda: {"name": "cprofile_manifold_url", "url": manifold_link}, + ) return retval return profile_wrapper diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index a61adb6a826b6..ac07f588107a2 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -3,7 +3,7 @@ import os import sys import tempfile -from typing import Any, Dict +from typing import Any, Dict, Optional import torch @@ -195,6 +195,6 @@ def 
max_clock_rate(): REQUIRES_SET_PYTHON_MODULE = False -def maybe_upload_prof_stats_to_manifold(profile_path: str) -> None: +def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: print("Uploading profile stats (fb-only otherwise no-op)") - pass + return None From 82b4528788d0b946ce525b9c9b07b2ceb82f44f7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 21 May 2024 00:55:15 +0000 Subject: [PATCH 27/35] [cudagraph] fix verbose graph logging (#126694) According to the [doc](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0907ca7a1e7d0211b71ee49c5403072b): > enum cudaGraphDebugDotFlags > CUDA Graph debug write options > > Values > cudaGraphDebugDotFlagsVerbose = 1<<0 > Output all debug data as if every debug flag is enabled > cudaGraphDebugDotFlagsKernelNodeParams = 1<<2 > Adds cudaKernelNodeParams to output > cudaGraphDebugDotFlagsMemcpyNodeParams = 1<<3 > Adds cudaMemcpy3DParms to output > cudaGraphDebugDotFlagsMemsetNodeParams = 1<<4 > Adds cudaMemsetParams to output > cudaGraphDebugDotFlagsHostNodeParams = 1<<5 > Adds cudaHostNodeParams to output > cudaGraphDebugDotFlagsEventNodeParams = 1<<6 > Adds cudaEvent_t handle from record and wait nodes to output > cudaGraphDebugDotFlagsExtSemasSignalNodeParams = 1<<7 > Adds cudaExternalSemaphoreSignalNodeParams values to output > cudaGraphDebugDotFlagsExtSemasWaitNodeParams = 1<<8 > Adds cudaExternalSemaphoreWaitNodeParams to output > cudaGraphDebugDotFlagsKernelNodeAttributes = 1<<9 > Adds cudaKernelNodeAttrID values to output > cudaGraphDebugDotFlagsHandles = 1<<10 > Adds node handles and every kernel function handle to output > cudaGraphDebugDotFlagsConditionalNodeParams = 1<<15 > Adds cudaConditionalNodeParams to output > `1 << 10` is not the most verbose flag. it is just one flag to add node handles and every kernel function handle to output. `1 << 0` is the most verbose flag, under the name `cudaGraphDebugDotFlagsVerbose`. 
Here is an example of graph, dumped with `1 << 10`: ```dot digraph dot { subgraph cluster_1 { label="graph_1" graph[style="dashed"]; "graph_1_node_0"[style="solid" shape="rectangle" label="0 MEM_ALLOC node handle: 0x000055D2889750F0 "]; "graph_1_node_1"[style="bold" shape="octagon" label="1 _Z3addPhS_S_m node handle: 0x000055D288979A20 func handle: 0x000055D288978D40 "]; "graph_1_node_2"[style="solid" shape="trapezium"label="2 MEMCPY node handle: 0x000055D28897A130 (DtoH,1024) "]; "graph_1_node_3"[style="solid" shape="rectangle" label="3 MEM_FREE node handle: 0x000055D2889890C0 "]; "graph_1_node_0" -> "graph_1_node_1"; "graph_1_node_1" -> "graph_1_node_2"; "graph_1_node_2" -> "graph_1_node_3"; } } ``` The same graph dumped with `1 << 0`: ```dot digraph dot { subgraph cluster_1 { label="graph_1" graph[style="dashed"]; "graph_1_node_0"[style="solid" shape="record" label="{ MEM_ALLOC | {{ID | node handle} | {0 (topoId: 3) | 0x000055D2889750F0}} | {{{poolProps | {allocType | handleTypes | {location | {type | id}}} | {PINNED | NONE | DEVICE | 0}}}} | {{bytesize | dptr} | {1024 | 0x0000000A02000000}} }"]; "graph_1_node_1"[style="bold" shape="record" label="{KERNEL | {ID | 1 (topoId: 2) | _Z3addPhS_S_m\<\<\<4,256,0\>\>\>} | {{node handle | func handle} | {0x000055D288979A20 | 0x000055D288978D40}} | {accessPolicyWindow | {base_ptr | num_bytes | hitRatio | hitProp | missProp} | {0x0000000000000000 | 0 | 0.000000 | N | N}} | {cooperative | 0} | {priority | 0} }"]; "graph_1_node_2"[style="solid" shape="record" label="{ MEMCPY | {{ID | node handle} | {2 (topoId: 1) | 0x000055D28897A130}} | {kind | DtoH (DEVICE to HOST PAGEABLE)} | {{srcPtr | dstPtr} | {pitch | ptr | xsize | ysize | pitch | ptr | xsize | ysize} | {0 | 0x0000000A02000000 | 0 | 0 | 0 | 0x000055D287CA6DB0 | 0 | 0}} | {{srcPos | {{x | 0} | {y | 0} | {z | 0}}} | {dstPos | {{x | 0} | {y | 0} | {z | 0}}} | {Extent | {{Width | 1024} | {Height | 1} | {Depth | 1}}}} }"]; "graph_1_node_3"[style="solid" shape="record" label="{ MEM_FREE | {{ID | node handle} | {3 (topoId: 0) | 0x000055D2889890C0}} | {{dptr} | {0x0000000A02000000}} }"]; "graph_1_node_0" -> "graph_1_node_1" [headlabel=0]; "graph_1_node_1" -> "graph_1_node_2" [headlabel=0]; "graph_1_node_2" -> "graph_1_node_3" [headlabel=0]; } } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126694 Approved by: https://github.com/eqy, https://github.com/eellison --- aten/src/ATen/cuda/CUDAGraph.cpp | 2 +- torch/utils/hipify/constants.py | 2 +- torch/utils/hipify/cuda_to_hip_mappings.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 01d3d513c4ebb..e93a8561b2ced 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -268,7 +268,7 @@ void CUDAGraph::debug_dump(const std::string& debug_path) { TORCH_WARN("DEBUG: calling debug_dump()"); if (has_graph_) { TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path); - C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output + C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output AT_CUDA_CHECK(cudaGraphDestroy(graph_)); } } else { diff --git a/torch/utils/hipify/constants.py b/torch/utils/hipify/constants.py index fb56e7a77a3ed..a9053b261ad44 100644 --- a/torch/utils/hipify/constants.py +++ b/torch/utils/hipify/constants.py @@ -2,7 +2,7 @@ The constants defined here are used to 
annotate the mapping tuples in cuda_to_hip_mappings.py. They are based on -https://github.com/ROCm-Developer-Tools/HIP/blob/master/hipify-clang/src/Statistics.h +https://github.com/ROCm/HIPIFY/blob/master/src/Statistics.h and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported mapping. """ diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 3c84e1bff4c9d..976e12e42d336 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -4163,6 +4163,7 @@ ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), ("cudaGraphGetNodes", ("hipGraphGetNodes", CONV_TYPE, API_RUNTIME)), ("cudaGraphDebugDotPrint", ("hipGraphDebugDotPrint", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)), From 31ba6ee49bdbd7dd6c2a3c77ad31cc46aec2049e Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 20 May 2024 16:13:25 -0400 Subject: [PATCH 28/35] Traceable wrapper subclass support for deferred runtime asserts (#126198) The padded dense -> jagged conversion op has the signature: ``` _fbgemm_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor ``` when `total_L` is not specified, the meta registration has a data-dependent output shape (based on `offsets[0][-1]`). Returning an unbacked SymInt here should work in theory, but traceable wrapper subclass support is missing in later code to handle deferred runtime asserts. This PR fixes this. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126198 Approved by: https://github.com/ezyang --- docs/source/fx.experimental.rst | 1 + test/allowlist_for_publicAPI.json | 3 ++- torch/fx/experimental/symbolic_shapes.py | 20 ++++++++++++++++++++ torch/fx/passes/runtime_assert.py | 8 ++++++++ 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/source/fx.experimental.rst b/docs/source/fx.experimental.rst index 6abfb89971cd9..d6885eb41ca07 100644 --- a/docs/source/fx.experimental.rst +++ b/docs/source/fx.experimental.rst @@ -30,6 +30,7 @@ torch.fx.experimental.symbolic_shapes CallMethodKey PropagateUnbackedSymInts DivideByKey + InnerTensorKey hint_int is_concrete_int diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index ac77325188ee0..c3d3fe2f00ec8 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2027,6 +2027,7 @@ "uninteresting_files", "CallMethodKey", "DivideByKey", + "InnerTensorKey", "PropagateUnbackedSymInts", "ShapeEnvSettings", "log_lru_cache_stats", @@ -2752,4 +2753,4 @@ "torch.utils.hipify.hipify_python": [ "TrieNode" ] -} \ No newline at end of file +} diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 7398fd4bf56a0..ca6e5957e20e6 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -489,6 +489,18 @@ def get(self, o: Any) -> Any: return getattr(o, self.name)() +@dataclass(frozen=True) +class InnerTensorKey: + inner_name: str + + def __str__(self) -> str: + return f".{self.inner_name}" + + def get(self, o: Any) -> Any: + """Get the inner tensor attribute""" + return getattr(o, self.inner_name) + + @dataclass(frozen=True) class DivideByKey: divisor: int @@ -538,6 +550,14 @@ def free_unbacked_symbols_with_path( real=real[i] if real is not None else None ) ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update( + free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),)) + ) elif isinstance(a, torch.Tensor): r.update( free_unbacked_symbols_with_path( diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 0d45defe8a48c..843f5f37e1dab 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -57,6 +57,7 @@ def insert_deferred_runtime_asserts( ConvertIntKey, DivideByKey, free_symbols, + InnerTensorKey, ) from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.reference import PythonReferenceAnalysis @@ -225,6 +226,13 @@ def go(node, keypath): ), keypath[1:], ) + elif isinstance(keypath[0], InnerTensorKey): + return go( + graph.call_function( + getattr, (node, keypath[0].inner_name) + ), + keypath[1:], + ) else: raise AssertionError(f"unrecognized keypath {keypath}") From b948b1ad7a9cf61c9692506c60c295fd40e00f43 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 20 May 2024 15:32:40 -0700 Subject: [PATCH 29/35] [pipelining] Add pipeline stage test (#126721) Test tracer's and manual's stage creation by using a basic schedule (GPipe). 
(Migrated from https://github.com/pytorch/PiPPy/blob/main/test/test_pipeline_stage.py) Test command: ``` $ python test_stage.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126721 Approved by: https://github.com/wconstab, https://github.com/H-Huang --- test/distributed/pipelining/model_registry.py | 26 ++- ...est_stage_backward.py => test_backward.py} | 0 test/distributed/pipelining/test_chunkspec.py | 4 +- test/distributed/pipelining/test_schedule.py | 53 +---- test/distributed/pipelining/test_stage.py | 198 ++++++++++++++++++ torch/distributed/pipelining/__init__.py | 3 +- 6 files changed, 228 insertions(+), 56 deletions(-) rename test/distributed/pipelining/{test_stage_backward.py => test_backward.py} (100%) create mode 100644 test/distributed/pipelining/test_stage.py diff --git a/test/distributed/pipelining/model_registry.py b/test/distributed/pipelining/model_registry.py index f88bebd3a5598..ca811de3d75db 100644 --- a/test/distributed/pipelining/model_registry.py +++ b/test/distributed/pipelining/model_registry.py @@ -17,9 +17,8 @@ def __init__(self, d_hid: int = default_dhid): self.lin0 = torch.nn.Linear(d_hid, d_hid) self.lin1 = torch.nn.Linear(d_hid, d_hid) - def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)): + def forward(self, x): x = torch.mm(x, self.mm_param0) - x = x + y x = torch.relu(x) # try passing a value that doesn't require_grad across skip boundaries a_constant = self.cval.clone() @@ -32,6 +31,29 @@ def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)): return x +class ModelWithKwargs(torch.nn.Module): + default_dhid = 512 + default_batch_size = 256 + + def __init__(self, d_hid: int = default_dhid): + super().__init__() + self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin0 = torch.nn.Linear(d_hid, d_hid) + self.lin1 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)): + x = torch.mm(x, self.mm_param0) + x = x + y + x = self.lin0(x) + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin1(x) + x = torch.relu(x) + return x + + # MLP Layer class MLPModule(torch.nn.Module): def __init__(self, d_hid): diff --git a/test/distributed/pipelining/test_stage_backward.py b/test/distributed/pipelining/test_backward.py similarity index 100% rename from test/distributed/pipelining/test_stage_backward.py rename to test/distributed/pipelining/test_backward.py diff --git a/test/distributed/pipelining/test_chunkspec.py b/test/distributed/pipelining/test_chunkspec.py index 050a7b11a21bc..1b104e59ec779 100644 --- a/test/distributed/pipelining/test_chunkspec.py +++ b/test/distributed/pipelining/test_chunkspec.py @@ -16,7 +16,7 @@ torch.manual_seed(0) -class ExampleCode(torch.nn.Module): +class ModelWithKwargs(torch.nn.Module): def __init__(self): super().__init__() self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) @@ -44,7 +44,7 @@ def forward(self, x, y, z=torch.zeros(batch_size, d_hid)): class ChunkSpecTests(TestCase): def test_chunk_spec(self): - mod = ExampleCode() + mod = ModelWithKwargs() x = torch.randn(batch_size, d_hid) y = torch.randn(batch_size, d_hid) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 8357f3b66108d..c1fb6b075f766 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -8,7 +8,7 @@ import torch import 
torch.distributed as dist -from model_registry import ExampleCode, MultiMLP +from model_registry import ModelWithKwargs, MultiMLP from torch.distributed.pipelining import ( pipeline, PipelineStage, @@ -50,60 +50,11 @@ def setUpClass(cls): dev_id = cls.rank % torch.cuda.device_count() cls.device = torch.device(f"cuda:{dev_id}") - @requires_nccl() - @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - def test_ec_forward(self): - # Setting this flag for numerical stability - torch.distributed.pipelining.microbatch._debug_mask_minibatches = True - - mod = ExampleCode(d_hid) - mod.to(self.device) - - x = torch.randn(batch_size, d_hid, device=self.device) - y = torch.randn(batch_size, d_hid, device=self.device) - - pipe = pipeline( - mod, - chunks, - example_args=(x,), - example_kwargs={"y": y}, - ) - - stage = PipelineStage( - pipe, - self.rank, - device=self.device, - ) - - # Attach to a schedule - schedule = ScheduleGPipe(stage, chunks) - - # Run - if self.rank == 0: - schedule.step(x, y=y) - else: - out = schedule.step() - - dist.barrier() - - # Last rank checks result - if self.rank == self.world_size - 1: - ref_out = mod(x, y=y) - torch.testing.assert_close(out, ref_out) - - # Test qualname mapping - submod_keys = stage.submod.state_dict().keys() - # Confirm keys are consistent with original model - old_keys = mod.state_dict().keys() - assert all(k in old_keys for k in submod_keys) - # Reset this flag - torch.distributed.pipelining.microbatch._debug_mask_minibatches = False - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_ec_backward(self, ScheduleClass): - mod = ExampleCode(d_hid) + mod = ModelWithKwargs(d_hid) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py new file mode 100644 index 0000000000000..20f40ea5fa298 --- /dev/null +++ b/test/distributed/pipelining/test_stage.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] +import os +import sys +import tempfile + +import torch +import torch.distributed as dist + +from model_registry import ExampleCode, ModelWithKwargs, MultiMLP +from torch.distributed.pipelining import ( + ManualPipelineStage, + pipeline, + PipelineStage, + ScheduleGPipe, +) +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, + requires_nccl, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + skip_but_pass_in_sandcastle_if, +) + + +d_hid = 512 +batch_size = 256 +chunks = 4 + +torch.manual_seed(0) + + +class StageTest(MultiProcContinousTest): + @classmethod + def backend_str(cls) -> str: + # Testing with NCCL backend + return "nccl" + + @classmethod + def setUpClass(cls): + """ + Class-scope test fixture. Run once for entire test class, before any test starts. + Set up the device. 
+ """ + super().setUpClass() + dev_id = cls.rank % torch.cuda.device_count() + cls.device = torch.device(f"cuda:{dev_id}") + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ModelClass", [ExampleCode, MultiMLP]) + def test_tracer(self, ModelClass): + mod = ModelClass(d_hid) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + + pipe = pipeline( + mod, + chunks, + example_args=(x,), + ) + + stage = PipelineStage( + pipe, + self.rank, + device=self.device, + ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + if self.rank == 0: + schedule.step(x) + else: + out = schedule.step() + + # Last rank checks result + if self.rank == self.world_size - 1: + ref_out = mod(x) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2) + + # Test qualname mapping + submod_keys = stage.submod.state_dict().keys() + # Confirm keys are consistent with original model + old_keys = mod.state_dict().keys() + assert all(k in old_keys for k in submod_keys) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ModelClass", [ModelWithKwargs]) + def test_tracer_kwargs(self, ModelClass): + mod = ModelClass(d_hid) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + y = torch.randn(batch_size, d_hid, device=self.device) + + pipe = pipeline( + mod, + chunks, + example_args=(x,), + example_kwargs={"y": y}, + ) + + stage = PipelineStage( + pipe, + self.rank, + device=self.device, + ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + if self.rank == 0: + schedule.step(x, y=y) + else: + out = schedule.step() + + # Last rank checks result + if self.rank == self.world_size - 1: + ref_out = mod(x, y=y) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2) + + # Test qualname mapping + submod_keys = stage.submod.state_dict().keys() + # Confirm keys are consistent with original model + old_keys = mod.state_dict().keys() + assert all(k in old_keys for k in submod_keys) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_manual(self): + full_mod = MultiMLP(d_hid).to(self.device) + stage_mod = full_mod.get_submodule(f"mlp{self.rank}") + stage_mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + + stage = ManualPipelineStage( + stage_mod, + self.rank, + self.world_size, + self.device, + chunks, + input_args=x.chunk(chunks)[0], + ) + + # Attach to a schedule + schedule = ScheduleGPipe(stage, chunks) + + # Run + if self.rank == 0: + schedule.step(x) + else: + out = schedule.step() + + # Last rank checks result + if self.rank == self.world_size - 1: + ref_out = full_mod(x) + torch.testing.assert_close(out, ref_out) + + +instantiate_parametrized_tests(StageTest) + +if __name__ == "__main__": + # Check if GPU and NCCL are available + if not ( + dist.is_available() + and dist.is_nccl_available() + and torch.cuda.device_count() > 1 + ): + print( + "c10d NCCL not available or not enough GPUs, skipping tests", + file=sys.stderr, + ) + sys.exit(0) + + rank = int(os.getenv("RANK", -1)) + world_size = int(os.getenv("WORLD_SIZE", 2)) + + if rank != -1: + # Launched with torchrun or other multi-proc launchers. Directly run the test. + StageTest.run_rank(rank, world_size) + else: + # Launched as a single process. Spawn subprocess to run the tests. 
+ # Also need a rendezvous file for `init_process_group` purpose. + rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + torch.multiprocessing.spawn( + StageTest.run_rank, + nprocs=world_size, + args=(world_size, rdvz_file), + ) diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 8ea9923dd44d1..45352d3da1b90 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -8,7 +8,7 @@ pipeline, SplitPoint, ) -from ._PipelineStage import PipelineStage +from ._PipelineStage import ManualPipelineStage, PipelineStage from .PipelineSchedule import ( Schedule1F1B, ScheduleGPipe, @@ -24,6 +24,7 @@ "pipeline", "ArgsChunkSpec", "KwargsChunkSpec", + "ManualPipelineStage", "PipelineStage", "Schedule1F1B", "ScheduleGPipe", From d30cdc43215d4b6440da4a90d01ca289c221d357 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Tue, 21 May 2024 01:59:26 +0000 Subject: [PATCH 30/35] [ROCm] amdsmi library integration (#119182) Adds monitoring support for ROCm using amdsmi in place of pynvml. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119182 Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/xw285cornell --- .ci/docker/centos-rocm/Dockerfile | 3 + .ci/docker/common/install_amdsmi.sh | 5 + .ci/docker/common/install_rocm.sh | 6 +- .ci/docker/ubuntu-rocm/Dockerfile | 5 + test/test_cuda.py | 5 +- tools/stats/monitor.py | 127 ++++++------------- torch/_dynamo/trace_rules.py | 5 + torch/cuda/__init__.py | 189 +++++++++++++++++++++++++--- torch/cuda/memory.py | 58 ++++++--- 9 files changed, 274 insertions(+), 129 deletions(-) create mode 100644 .ci/docker/common/install_amdsmi.sh diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index bcf028812a887..6cb82a1f770c5 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -77,6 +77,9 @@ RUN rm install_rocm.sh COPY ./common/install_rocm_magma.sh install_rocm_magma.sh RUN bash ./install_rocm_magma.sh RUN rm install_rocm_magma.sh +COPY ./common/install_amdsmi.sh install_amdsmi.sh +RUN bash ./install_amdsmi.sh +RUN rm install_amdsmi.sh ENV PATH /opt/rocm/bin:$PATH ENV PATH /opt/rocm/hcc/bin:$PATH ENV PATH /opt/rocm/hip/bin:$PATH diff --git a/.ci/docker/common/install_amdsmi.sh b/.ci/docker/common/install_amdsmi.sh new file mode 100644 index 0000000000000..c16c262f0e61f --- /dev/null +++ b/.ci/docker/common/install_amdsmi.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +set -ex + +cd /opt/rocm/share/amd_smi && pip install . 
diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 5659b487f8380..6b746d2f92b48 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -39,7 +39,8 @@ install_ubuntu() { rocm-libs \ rccl \ rocprofiler-dev \ - roctracer-dev + roctracer-dev \ + amd-smi-lib if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev @@ -106,7 +107,8 @@ install_centos() { rocm-libs \ rccl \ rocprofiler-dev \ - roctracer-dev + roctracer-dev \ + amd-smi-lib # precompiled miopen kernels; search for all unversioned packages # if search fails it will abort this script; use true to avoid case where search fails diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 9964f5c3fa91b..cc43d9ec24142 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -78,6 +78,11 @@ ENV MAGMA_HOME /opt/rocm/magma ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 +# Install amdsmi +COPY ./common/install_amdsmi.sh install_amdsmi.sh +RUN bash ./install_amdsmi.sh +RUN rm install_amdsmi.sh + # (optional) Install non-default CMake version ARG CMAKE_VERSION COPY ./common/install_cmake.sh install_cmake.sh diff --git a/test/test_cuda.py b/test/test_cuda.py index 93e08eff4df6d..c1b990381bbe8 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -4251,7 +4251,10 @@ def free(): @unittest.skipIf(TEST_PYNVML, "pynvml is not available") def test_nvml_get_handler(self): - self.assertTrue(torch.cuda._get_pynvml_handler() is not None) + if not torch.version.hip: + self.assertTrue(torch.cuda._get_pynvml_handler() is not None) + else: + self.assertTrue(torch.cuda._get_amdsmi_handler() is not None) @unittest.skipIf(TEST_PYNVML, "pynvml is not available") def test_temperature(self): diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index cd43fe5a8e7f1..e5c56f49de3b2 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -2,28 +2,10 @@ import datetime import json import signal -import sys import time from typing import Any, Dict, List import psutil # type: ignore[import] -import pynvml # type: ignore[import] - -# ROCm does not currently have the rocm_smi module installed to a pythonic location. -# Must import from ROCm installation path. -# Cannot use the high-level rocm_smi cmdline module due to its use of exit(). -# Must use the lower-level ctypes wrappers exposed through rsmiBindings. 
-sys.path.append("/opt/rocm/libexec/rocm_smi") -try: - from ctypes import byref, c_uint32, c_uint64 - - from rsmiBindings import ( # type: ignore[import] - rocmsmi, - rsmi_process_info_t, - rsmi_status_t, - ) -except ImportError as e: - pass def get_processes_running_python_tests() -> List[Any]: @@ -76,78 +58,42 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: return per_process_info -def rocm_ret_ok(ret: int) -> Any: - return ret == rsmi_status_t.RSMI_STATUS_SUCCESS - - -def rocm_list_devices() -> List[int]: - num = c_uint32(0) - ret = rocmsmi.rsmi_num_monitor_devices(byref(num)) - if rocm_ret_ok(ret): - return list(range(num.value)) - return [] - - -def rocm_get_mem_use(device: int) -> float: - memoryUse = c_uint64() - memoryTot = c_uint64() - - ret = rocmsmi.rsmi_dev_memory_usage_get(device, 0, byref(memoryUse)) - if rocm_ret_ok(ret): - ret = rocmsmi.rsmi_dev_memory_total_get(device, 0, byref(memoryTot)) - if rocm_ret_ok(ret): - return float(memoryUse.value) / float(memoryTot.value) - return 0.0 - - -def rocm_get_gpu_use(device: int) -> float: - percent = c_uint32() - ret = rocmsmi.rsmi_dev_busy_percent_get(device, byref(percent)) - if rocm_ret_ok(ret): - return float(percent.value) - return 0.0 - - -def rocm_get_pid_list() -> List[Any]: - num_items = c_uint32() - ret = rocmsmi.rsmi_compute_process_info_get(None, byref(num_items)) - if rocm_ret_ok(ret): - buff_sz = num_items.value + 10 - procs = (rsmi_process_info_t * buff_sz)() - procList = [] - ret = rocmsmi.rsmi_compute_process_info_get(byref(procs), byref(num_items)) - for i in range(num_items.value): - procList.append(procs[i].process_id) - return procList - return [] - - -def rocm_get_per_process_gpu_info() -> List[Dict[str, Any]]: +def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: + processes = amdsmi.amdsmi_get_gpu_process_list(handle) per_process_info = [] - for pid in rocm_get_pid_list(): - proc = rsmi_process_info_t() - ret = rocmsmi.rsmi_compute_process_info_by_pid_get(int(pid), byref(proc)) - if rocm_ret_ok(ret): - info = {"pid": pid, "gpu_memory": proc.vram_usage} - per_process_info.append(info) + for p in processes: + proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) + info = { + "pid": proc_info["pid"], + "gpu_memory": proc_info["memory_usage"]["vram_mem"], + } + per_process_info.append(info) return per_process_info if __name__ == "__main__": handle = None try: - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - except pynvml.NVMLError: + import pynvml # type: ignore[import] + + try: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + except pynvml.NVMLError: + pass + except ModuleNotFoundError: # no pynvml avaliable, probably because not cuda pass - - rsmi_handles = [] try: - ret = rocmsmi.rsmi_init(0) - rsmi_handles = rocm_list_devices() - except Exception: - # no rocmsmi available, probably because not rocm + import amdsmi # type: ignore[import] + + try: + amdsmi.amdsmi_init() + amdsmi_handle = amdsmi.amdsmi_get_processor_handles()[0] + except amdsmi.AmdSmiException: + pass + except ModuleNotFoundError: + # no amdsmi is available pass kill_now = False @@ -171,17 +117,16 @@ def exit_gracefully(*args: Any) -> None: gpu_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) stats["total_gpu_utilization"] = gpu_utilization.gpu stats["total_gpu_mem_utilization"] = gpu_utilization.memory - if rsmi_handles: - stats["per_process_gpu_info"] = rocm_get_per_process_gpu_info() - # There are 1 to 4 GPUs in use; these values may sum > 1.0. 
- gpu_utilization = 0.0 - gpu_memory = 0.0 - for dev in rsmi_handles: - gpu_utilization += rocm_get_gpu_use(dev) - gpu_memory += rocm_get_mem_use(dev) - stats["total_gpu_utilization"] = gpu_utilization - stats["total_gpu_mem_utilization"] = gpu_memory - + if amdsmi_handle is not None: + stats["per_process_gpu_info"] = rocm_get_per_process_gpu_info( + amdsmi_handle + ) + stats["total_gpu_utilization"] = amdsmi.amdsmi_get_gpu_activity( + amdsmi_handle + )["gfx_activity"] + stats["total_gpu_mem_utilization"] = amdsmi.amdsmi_get_gpu_activity( + amdsmi_handle + )["umc_activity"] except Exception as e: stats = { "time": datetime.datetime.utcnow().isoformat("T") + "Z", diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 9ac3db8647089..b43f447737128 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2427,7 +2427,10 @@ "torch.cpu.synchronize", "torch.cuda._check_capability", "torch.cuda._check_cubins", + "torch.cuda._device_count_amdsmi", "torch.cuda._device_count_nvml", + "torch.cuda._get_amdsmi_handler", + "torch.cuda._get_amdsmi_device_index", "torch.cuda._get_device", "torch.cuda._get_generator", "torch.cuda._get_nvml_device_index", @@ -2458,7 +2461,9 @@ "torch.cuda._memory_viz.trace", "torch.cuda._nvml_based_avail", "torch.cuda._parse_visible_devices", + "torch.cuda._raw_device_count_amdsmi", "torch.cuda._raw_device_count_nvml", + "torch.cuda._raw_device_uuid_amdsmi", "torch.cuda._raw_device_uuid_nvml", "torch.cuda._register_triton_kernels", "torch.cuda._set_rng_state_offset", diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 1cc88e8adc578..8c19788d1055d 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -53,9 +53,18 @@ _HAS_PYNVML = False _PYNVML_ERR = None try: - import pynvml # type: ignore[import] + try: + import pynvml # type: ignore[import] + + _HAS_PYNVML = True + except ModuleNotFoundError: + pass + try: + import amdsmi # type: ignore[import] - _HAS_PYNVML = True + _HAS_PYNVML = True + except ModuleNotFoundError: + pass except ImportError as err: _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later @@ -571,7 +580,9 @@ def set_stream(stream: Stream): def _parse_visible_devices() -> Union[List[int], List[str]]: r"""Parse CUDA_VISIBLE_DEVICES environment variable.""" - var = os.getenv("CUDA_VISIBLE_DEVICES") + var = os.getenv( + "CUDA_VISIBLE_DEVICES" if not torch.version.hip else "HIP_VISIBLE_DEVICES" + ) if var is None: return list(range(64)) @@ -617,6 +628,16 @@ def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: return rc +def _raw_device_count_amdsmi() -> int: + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") + return -1 + socket_handles = amdsmi.amdsmi_get_processor_handles() + return len(socket_handles) + + def _raw_device_count_nvml() -> int: r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" from ctypes import byref, c_int, CDLL @@ -635,6 +656,36 @@ def _raw_device_count_nvml() -> int: return dev_count.value +def _raw_device_uuid_amdsmi() -> Optional[List[str]]: + from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException: + warnings.warn("Can't initialize amdsmi") + return None + try: + socket_handles = amdsmi.amdsmi_get_processor_handles() + dev_count = len(socket_handles) + except 
amdsmi.AmdSmiException: + warnings.warn("Can't get amdsmi device count") + return None + uuids: List[str] = [] + for idx in range(dev_count): + try: + handler = amdsmi.amdsmi_get_processor_handles()[idx] + except amdsmi.AmdSmiException: + warnings.warn("Cannot get amd device handler") + return None + try: + uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler) + except amdsmi.AmdSmiException: + warnings.warn("Cannot get uuid for amd device") + return None + uuids.append(str(uuid)) + return uuids + + def _raw_device_uuid_nvml() -> Optional[List[str]]: r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer @@ -694,6 +745,28 @@ def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: return rc +def _device_count_amdsmi() -> int: + visible_devices = _parse_visible_devices() + if not visible_devices: + return 0 + try: + if type(visible_devices[0]) is str: + return -1 + else: + raw_cnt = _raw_device_count_amdsmi() + if raw_cnt <= 0: + return raw_cnt + # Trim the list up to a maximum available device + for idx, val in enumerate(visible_devices): + if cast(int, val) >= raw_cnt: + return idx + except OSError: + return -1 + except AttributeError: + return -1 + return len(visible_devices) + + def _device_count_nvml() -> int: r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. @@ -758,7 +831,7 @@ def device_count() -> int: if _cached_device_count is not None: return _cached_device_count # bypass _device_count_nvml() if rocm (not supported) - nvml_count = -1 if torch.version.hip else _device_count_nvml() + nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml() r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count # NB: Do not cache the device count prior to CUDA initialization, because # the number of devices can change due to changes to CUDA_VISIBLE_DEVICES @@ -916,6 +989,68 @@ def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): return handle +def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None): + if not _HAS_PYNVML: + raise ModuleNotFoundError( + "amdsmi does not seem to be installed or it can't be imported." 
+ ) from _PYNVML_ERR + try: + amdsmi.amdsmi_init() + except amdsmi.AmdSmiException as e: + raise RuntimeError( + "amdsmi driver can't be loaded, requires >=ROCm5.6 installation" + ) from e + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return handle + + +def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: + r"""Return the amdsmi index of the device, taking HIP_VISIBLE_DEVICES into account.""" + idx = _get_device_index(device, optional=True) + visible_devices = _parse_visible_devices() + if type(visible_devices[0]) is str: + raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings") + idx_map = dict(enumerate(cast(List[int], visible_devices))) + if idx not in idx_map: + raise RuntimeError( + f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" + ) + return idx_map[idx] + + +def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler() + device = _get_amdsmi_device_index(device) + return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] + + +def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler() + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] + return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] + + +def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_temp_metric( + handle, + amdsmi.AmdSmiTemperatureType.JUNCTION, + amdsmi.AmdSmiTemperatureMetric.CURRENT, + ) + + +def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] + + +def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: + handle = _get_amdsmi_handler(device) + return amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)["cur_clk"] + + def memory_usage(device: Optional[Union[Device, int]] = None) -> int: r"""Return the percent of time over the past sample period during which global (device) memory was being read or written as given by `nvidia-smi`. @@ -928,11 +1063,13 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler() - - device = _get_nvml_device_index(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - return pynvml.nvmlDeviceGetUtilizationRates(handle).memory + if not torch.version.hip: + handle = _get_pynvml_handler() + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).memory + else: + return _get_amdsmi_memory_usage(device) def utilization(device: Optional[Union[Device, int]] = None) -> int: @@ -947,10 +1084,13 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. 
""" - handle = _get_pynvml_handler(device) - device = _get_nvml_device_index(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + if not torch.version.hip: + handle = _get_pynvml_handler(device) + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + else: + return _get_amdsmi_utilization(device) def temperature(device: Optional[Union[Device, int]] = None) -> int: @@ -966,9 +1106,12 @@ def temperature(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler(device) - # 0 refers to the temperature sensor for the GPU die. - return pynvml.nvmlDeviceGetTemperature(handle, 0) + if not torch.version.hip: + handle = _get_pynvml_handler(device) + # 0 refers to the temperature sensor for the GPU die. + return pynvml.nvmlDeviceGetTemperature(handle, 0) + else: + return _get_amdsmi_temperature(device) def power_draw(device: Optional[Union[Device, int]] = None) -> int: @@ -983,8 +1126,11 @@ def power_draw(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler(device) - return pynvml.nvmlDeviceGetPowerUsage(handle) + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetPowerUsage(handle) + else: + return _get_amdsmi_power_draw(device) def clock_rate(device: Optional[Union[Device, int]] = None) -> int: @@ -998,8 +1144,11 @@ def clock_rate(device: Optional[Union[Device, int]] = None) -> int: Warning: Each sample period may be between 1 second and 1/6 second, depending on the product being queried. """ - handle = _get_pynvml_handler(device) - return pynvml.nvmlDeviceGetClockInfo(handle, 1) + if not torch.version.hip: + handle = _get_pynvml_handler(device) + return pynvml.nvmlDeviceGetClockInfo(handle, 1) + else: + return _get_amdsmi_clock_rate(device) def _get_device(device: Union[int, str, torch.device]) -> torch.device: diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 8453842ef14a2..8a5110b10c98b 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -15,7 +15,13 @@ from torch.types import Device from .._utils import _dummy_type -from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized +from . import ( + _get_amdsmi_device_index, + _get_device_index, + _get_nvml_device_index, + _lazy_init, + is_initialized, +) from ._memory_viz import memory as _memory, segments as _segments @@ -609,26 +615,48 @@ def list_gpu_processes(device: Union[Device, int] = None) -> str: printout for the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). """ - try: - import pynvml # type: ignore[import] - except ModuleNotFoundError: - return "pynvml module not found, please install pynvml" - from pynvml import NVMLError_DriverNotLoaded + if not torch.version.hip: + try: + import pynvml # type: ignore[import] + except ModuleNotFoundError: + return "pynvml module not found, please install pynvml" + from pynvml import NVMLError_DriverNotLoaded + + try: + pynvml.nvmlInit() + except NVMLError_DriverNotLoaded: + return "cuda driver can't be loaded, is cuda enabled?" 
+ + device = _get_nvml_device_index(device) + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + else: + try: + import amdsmi # type: ignore[import] + except ModuleNotFoundError: + return "amdsmi module not found, please install amdsmi" + try: + amdsmi.amdsmi_init() # type: ignore[attr-defined] + except amdsmi.AmdSmiException: # type: ignore[attr-defined] + return "amdsmi driver can't be loaded, is ROCm installed?" + + device = _get_amdsmi_device_index(device) + handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined] + procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined] - try: - pynvml.nvmlInit() - except NVMLError_DriverNotLoaded: - return "cuda driver can't be loaded, is cuda enabled?" - device = _get_nvml_device_index(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) lines = [] lines.append(f"GPU:{device}") if len(procs) == 0: lines.append("no processes are running") for p in procs: - mem = p.usedGpuMemory / (1024 * 1024) - lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory") + if not torch.version.hip: + mem = p.usedGpuMemory / (1024 * 1024) + pid = p.pid + else: + proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined] + mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024) + pid = proc_info["pid"] + lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory") return "\n".join(lines) From d777685ef913d9251ab5739827e64af91e5a2e55 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 17 May 2024 13:00:12 -0700 Subject: [PATCH 31/35] Script for choosing template configurations (#126560) This adds logging that will mark any invocation of a matmul for a particular input shapes, and record every template configs performance on it. Then, we can parse that into a script which will minimize the total mm execution time given N allowed templates. And in future, other experiments.. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126560 Approved by: https://github.com/nmacchioni, https://github.com/jansel --- .../microbenchmarks/analyze_templates.py | 219 ++++++++++++++++++ torch/_inductor/select_algorithm.py | 74 ++++++ 2 files changed, 293 insertions(+) create mode 100644 benchmarks/dynamo/microbenchmarks/analyze_templates.py diff --git a/benchmarks/dynamo/microbenchmarks/analyze_templates.py b/benchmarks/dynamo/microbenchmarks/analyze_templates.py new file mode 100644 index 0000000000000..65fa547123a4b --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/analyze_templates.py @@ -0,0 +1,219 @@ +""" +This script uses linear programming to analyze outputs of triton mm config tuning. +To generate output that can be fed into this script set the env varTORCHINDUCTOR_MM_LOGGING_FILE. + +That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates. 
+""" +import json + +import click +import pulp + + +def parse_log_file(file_path): + with open(file_path) as f: + logs = json.load(f) + + occurrence_count = {} + benchmark_logs = {} + + # Parse the logs + for entry in logs: + if "invoke" in entry: + shape = entry["invoke"] + if shape not in occurrence_count: + occurrence_count[shape] = 0 + occurrence_count[shape] += 1 + else: + for shape, timings in entry.items(): + if shape not in benchmark_logs: + benchmark_logs[shape] = [] + benchmark_logs[shape].extend(timings) + + return occurrence_count, benchmark_logs + + +def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False): + # Set of all possible Triton templates keyed by their attributes + triton_templates = set() + for timings in benchmark_logs.values(): + for timing in timings: + if timing["type"] == "triton": + triton_templates.add( + ( + timing["BLOCK_M"], + timing["BLOCK_N"], + timing["BLOCK_K"], + timing["num_stages"], + timing["num_warps"], + ) + ) + + # Print the initial data + if verbose: + print("Occurrence Count:", occurrence_count) + print("Triton Templates:", triton_templates) + + # Create a dictionary to store template selection variables + template_vars = { + template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary) + for template in triton_templates + } + + # Variables to select specific timing option for each shape + selection_vars = { + (shape, "cublas"): pulp.LpVariable( + f"Select_{shape}_cublas", 0, 1, pulp.LpBinary + ) + for shape in occurrence_count + } + for shape in occurrence_count: + for template in triton_templates: + selection_vars[(shape, template)] = pulp.LpVariable( + f"Select_{shape}_{template}", 0, 1, pulp.LpBinary + ) + + # Variables for the total time for each shape + min_time_vars = pulp.LpVariable.dicts( + "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous + ) + + # Define the problem + prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize) + + # Objective: Minimize the weighted total time + prob += pulp.lpSum( + [occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count] + ) + + # Constraints to select exactly N templates + prob += pulp.lpSum([template_vars[template] for template in triton_templates]) == N + + # Store triton options per shape for debugging + triton_options_per_shape = {} + + # Constraints for the total time for each shape + for shape in occurrence_count: + # Get cuBLAS time + cublas_times = [ + timing["time"] + for timing in benchmark_logs[shape] + if timing["type"] == "cublas" + ] + min_cublas_time = min(cublas_times) + + # Collect Triton options + triton_options = [] + for template in triton_templates: + triton_times = [ + timing["time"] + for timing in benchmark_logs[shape] + if timing["type"] == "triton" + and ( + timing["BLOCK_M"], + timing["BLOCK_N"], + timing["BLOCK_K"], + timing["num_stages"], + timing["num_warps"], + ) + == template + ] + if triton_times: + min_triton_time = min(triton_times) + triton_options.append((min_triton_time, template)) + + # Save triton options for debugging + triton_options_per_shape[shape] = triton_options + + # Ensure exactly one timing option is selected for each shape + prob += ( + pulp.lpSum( + [selection_vars[(shape, "cublas")]] + + [ + selection_vars[(shape, template)] + for triton_time, template in triton_options + ] + ) + == 1 + ) + + # Ensure min_time_vars[shape] matches the selected timing option + prob += min_time_vars[shape] == ( + selection_vars[(shape, "cublas")] * min_cublas_time + + pulp.lpSum( + [ + 
selection_vars[(shape, template)] * triton_time + for triton_time, template in triton_options + ] + ) + ) + + # Ensure Triton templates can only be selected if they are included in the N allowed templates + for triton_time, template in triton_options: + prob += selection_vars[(shape, template)] <= template_vars[template] + + # Print the constraints + if verbose: + print("Constraints:") + for constraint in prob.constraints.values(): + print(constraint) + + # Solve the problem with suppressed output + prob.solve(pulp.PULP_CBC_CMD(msg=False)) + + # Output the selected templates and their configurations + selected_templates = [ + template + for template in triton_templates + if pulp.value(template_vars[template]) == 1 + ] + total_time = sum( + pulp.value(min_time_vars[shape]) * occurrence_count[shape] + for shape in occurrence_count + ) + + # Print the values of the decision variables after solving + if verbose: + print("Decision Variable Values:") + for var in prob.variables(): + print(f"{var.name} = {var.varValue}") + + # # Debugging information + if verbose: + for shape in occurrence_count: + print(f"Shape: {shape}") + print(f" Min Time: {pulp.value(min_time_vars[shape])}") + print(f" Occurrences: {occurrence_count[shape]}") + print( + f" Min CuBLAS Time: {min_cublas_time} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}" + ) + for triton_time, template in triton_options_per_shape[shape]: + print( + f" Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}" + ) + + return selected_templates, total_time + + +# Main code to parse the log file and optimize templates +@click.command() +@click.argument("filename") +@click.option("--min-templates", default=0, help="Minimum number of templates.") +@click.option("--max-templates", default=10, help="Maximum number of templates.") +@click.option("--verbose", is_flag=True, help="Enable verbose output.") +def main(filename, min_templates, max_templates, verbose): + occurrence_count, benchmark_logs = parse_log_file(filename) + times = [] + for N in range(min_templates, max_templates + 1): + selected_templates, total_time = optimize_templates( + N, occurrence_count, benchmark_logs, verbose + ) + print(f"N = {N}") + print(f"Selected Templates: {selected_templates}") + print(f"Total Weighted Time: {total_time}") + times.append(total_time) + print(times) + + +if __name__ == "__main__": + main() diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5cb10e1820cf9..f843a98039a82 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2,6 +2,7 @@ import functools import inspect import itertools +import json import logging import math @@ -17,6 +18,7 @@ from unittest.mock import patch import sympy +from filelock import FileLock import torch from torch._dynamo.testing import rand_strided @@ -908,6 +910,34 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } +@functools.lru_cache(None) +def get_mm_log_filename() -> Optional[str]: + mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) + if not mm_file_name: + return None + + if "json" not in mm_file_name: + mm_file_name = f"{mm_file_name}.json" + + return mm_file_name + + +def append_to_log(filename, data): + lock_file = filename.replace(".json", ".lock") + lock = FileLock(lock_file) + with lock: + try: + with open(filename) as f: + log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + log_data = [] + + 
log_data.append(data) + + with open(filename, "w") as f: + json.dump(log_data, f, indent=4) + + class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" @@ -963,6 +993,11 @@ def __call__( # TODO(nmacchioni): remove once CI tests are fixed choices = [choice for choice in choices if choice is not None] + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + if len(choices) == 0: raise NoValidChoicesError( "No choices to select, please consider adding ATEN into max_autotune_gemm_backends " @@ -1344,9 +1379,48 @@ def log_results( for n in input_nodes ] ) + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 top_k = sorted(timings, key=timings.__getitem__)[:n] best = top_k[0] + + def get_choice_info(choice): + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "cublas", "time": timings[choice]} + + assert isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ) + + info = choice.info_dict() + tile = info["tile_shape"] + + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + + return { + "type": "triton", + "time": timings[choice], + "BLOCK_M": BLOCK_M, + "BLOCK_K": BLOCK_K, + "BLOCK_N": BLOCK_N, + "num_stages": info["num_stages"], + "num_warps": info["num_warps"], + } + + mm_filename = get_mm_log_filename() + if mm_filename and "mm" in name: + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + + out_dict = { + str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] + } + + append_to_log(mm_filename, out_dict) + best_time = timings[best] sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") for choice in top_k: From 51c07f9f69aedf884fc697f3ef8545cb0303e2a9 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 20 May 2024 14:19:34 +0100 Subject: [PATCH 32/35] [dynamo] Allow asserts to fail (#126661) Currently if an assertion is statically known to be false, dynamo converts it to `_assert_async` which inductor currently ignores. Instead this graph breaks to raise the original assertion. 
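As a minimal illustration of the user-visible behavior (mirroring the `test_assert` case added in this patch; the function and input below are made up):

```python
import torch

@torch.compile
def fn(x):
    # Statically false for every input: dynamo now graph-breaks here, so the
    # plain Python AssertionError is raised instead of being lowered to a
    # torch._assert_async call that inductor would ignore.
    assert x.shape != x.shape
    return x

try:
    fn(torch.randn(10))
except AssertionError:
    print("assertion surfaced eagerly, as expected")
```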
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126661
Approved by: https://github.com/ezyang
---
 test/dynamo/test_misc.py                        | 15 +++++++++++++++
 .../TestPythonRegistration.test_alias_analysis  |  0
 .../TestScript.test_unspecialized_any_binding   |  0
 torch/_dynamo/symbolic_convert.py               |  8 +++++---
 4 files changed, 20 insertions(+), 3 deletions(-)
 delete mode 100644 test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis
 delete mode 100644 test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding

diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 5d7f780457d09..f07021c315585 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -1358,6 +1358,21 @@ def f(x):
 
         self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))
 
+    def test_assert(self):
+        @torch.compile
+        def fn1(x):
+            assert x.shape != x.shape
+
+        with self.assertRaises(AssertionError):
+            a = torch.randn(10)
+            fn1(a)
+
+        def fn2(x):
+            assert x.shape == x.shape
+            return x.abs()
+
+        torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1)
+
     def test_config_obj(self):
         class Cfg:
             def __init__(self):
diff --git a/test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis b/test/dynamo_expected_failures/TestPythonRegistration.test_alias_analysis
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding b/test/dynamo_expected_failures/TestScript.test_unspecialized_any_binding
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 093809703405f..864e53777941e 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -343,9 +343,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
         ):
             error_msg: VariableTracker = self.pop()
             # Skip over things like `assert True`
-            if value.is_python_constant() and bool(value.as_python_constant()):
-                self.jump(inst)
-                return
+            if value.is_python_constant():
+                if bool(value.as_python_constant()):
+                    return self.jump(inst)
+                else:
+                    jump_graph_break(self, inst, value)
 
             # TODO maybe should respect DtoH sync intention of users later??
             # Manually insert torch._assert_async instead of python assert and jump over

From 40cc616909ba35dc68e8d11d4570091a4b723a5a Mon Sep 17 00:00:00 2001
From: cdzhan
Date: Tue, 21 May 2024 03:20:13 +0000
Subject: [PATCH 33/35] =?UTF-8?q?Fix=20caching=20allocator=20of=20out-of-t?=
 =?UTF-8?q?ree=20device=20is=20destructed=20before=20the=20=E2=80=A6=20(#1?=
 =?UTF-8?q?26677)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…destruction of tensors cached by autocast

## Root Cause
An out-of-tree device extension lives in a separate .so that is loaded after torch, so the global variable `cached_casts` may be constructed before the caching allocator. Because statics are destroyed in reverse order of construction, `cached_casts` is then destructed after the allocator at process exit, and the tensors it still holds are freed through an allocator that no longer exists.

## Fix
Lazily initialize `cached_casts` so that it is constructed after (and therefore destructed before) the caching allocator.

## How to Reproduce && Test
Modify the test case `TestAutocastGPU.test_cast_cache_is_global` in test/test_autocast.py to run on your out-of-tree device. You will see the following failure at the end of the test.
```bash ---------------------------------------------------------------------- Ran 1 test in 4.812s OK free: 0x30080ff44000400 terminate called after throwing an instance of 'c10::Error' what(): invalid device pointer: 0x30080ff44000400 Exception raised from free at /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/framework/core/caching_allocator.cpp:1609 (most recent call first): frame #0: + 0x118fe1 (0x7ffaef4d3fe1 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #1: + 0x11b1c4 (0x7ffaef4d61c4 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #2: + 0x117677 (0x7ffaef4d2677 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #3: + 0x11a2bf (0x7ffaef4d52bf in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #4: + 0x11a186 (0x7ffaef4d5186 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #5: + 0x119fde (0x7ffaef4d4fde in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #6: + 0x119d2e (0x7ffaef4d4d2e in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #7: + 0x119be0 (0x7ffaef4d4be0 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #8: + 0x119977 (0x7ffaef4d4977 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #9: + 0x119313 (0x7ffaef4d4313 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #10: + 0x118b4c (0x7ffaef4d3b4c in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #11: c10::Error::Error(c10::SourceLocation, std::string) + 0x34 (0x7ffaef4d27c4 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #12: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x7f (0x7ffaef4d04ed in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #13: torch_mlu::MLUCachingAllocator::Native::NativeCachingAllocator::free(void*) + 0xe6 (0x7ff9a8eeb112 in /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/lib/libtorch_mlu.so) frame #14: torch_mlu::MLUCachingAllocator::Native::local_raw_delete(void*) + 0x3b (0x7ff9a8ed9480 in /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/lib/libtorch_mlu.so) frame #15: std::unique_ptr::~unique_ptr() + 0x50 (0x7ffb0a5ea322 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #16: + 0x1269890 (0x7ffb0a5e4890 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #17: + 0x1269928 (0x7ffb0a5e4928 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #18: + 0x127572c (0x7ffb0a5f072c in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #19: + 0x1275758 (0x7ffb0a5f0758 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #20: + 0xb9bc7 (0x7ffaef474bc7 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #21: + 0xb97bc (0x7ffaef4747bc in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #22: + 0xdbc50 (0x7ffaef496c50 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #23: c10::TensorImpl::~TensorImpl() + 0x82 (0x7ffaef49157e in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #24: c10::TensorImpl::~TensorImpl() + 0x1c (0x7ffaef4915aa in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #25: + 0x2f596d9 (0x7ffaf24fc6d9 in 
/projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #26: + 0x2f589c2 (0x7ffaf24fb9c2 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #27: + 0x2f57b92 (0x7ffaf24fab92 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #28: + 0x2f5c228 (0x7ffaf24ff228 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #29: + 0x30f3f70 (0x7ffaf2696f70 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #30: + 0x30f3f90 (0x7ffaf2696f90 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #31: + 0x30f5004 (0x7ffaf2698004 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #32: + 0x30f5024 (0x7ffaf2698024 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #33: + 0x31207f0 (0x7ffaf26c37f0 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #34: + 0x3120814 (0x7ffaf26c3814 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #35: + 0x30f51e8 (0x7ffaf26981e8 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #36: + 0x30f5148 (0x7ffaf2698148 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #37: + 0x316ecea (0x7ffaf2711cea in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #38: + 0x468a7 (0x7ffb0c9ed8a7 in /lib/x86_64-linux-gnu/libc.so.6) frame #39: on_exit + 0 (0x7ffb0c9eda60 in /lib/x86_64-linux-gnu/libc.so.6) frame #47: __libc_start_main + 0xf3 (0x7ffb0c9cb083 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126677 Approved by: https://github.com/ezyang --- aten/src/ATen/autocast_mode.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 2d01bdeca500b..f0c73cde2dda3 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -35,7 +35,11 @@ namespace { // directly against incoming TensorImpl*s. 
using weakref_type = c10::weak_intrusive_ptr; using val_type = std::tuple; -ska::flat_hash_map cached_casts; + +static ska::flat_hash_map& get_cached_casts() { + static ska::flat_hash_map cached_casts; + return cached_casts; +} std::mutex cached_casts_mutex; @@ -82,7 +86,7 @@ thread_local bool cache_enabled = true; void clear_cache() { const std::lock_guard lock(cached_casts_mutex); - cached_casts.clear(); + get_cached_casts().clear(); } int increment_nesting() { @@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_ if (can_try_cache) { const std::lock_guard lock(cached_casts_mutex); - auto it = cached_casts.find(arg.unsafeGetTensorImpl()); - if (it != cached_casts.end()) { + auto it = get_cached_casts().find(arg.unsafeGetTensorImpl()); + if (it != get_cached_casts().end()) { return std::get<1>(it->second); } else { auto casted_arg = arg.to(to_type); - cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); + get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); return casted_arg; } } else { From 5f64086d08434cebd2bcbbbe335e42d5c1079ac9 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 21 May 2024 03:25:27 +0000 Subject: [PATCH 34/35] [NT][SDPA] Bump tolerances for `test_sdpa_with_packed_in_proj_cuda_bfloat16` (#126356) Current tolerances fail on RTX 6000 (Ada) with `Mismatched elements: 2 / 144 (1.4%)` ``` AssertionError: Tensor-likes are not close! Mismatched elements: 2 / 144 (1.4%) Greatest absolute difference: 0.002197265625 at index (5, 0, 0) (up to 0.001 allowed) Greatest relative difference: 0.08203125 at index (3, 0, 0) (up to 0.016 allowed) To execute this test, run the following from the base repo dir: python test/test_nestedtensor.py -k test_sdpa_with_packed_in_proj_cuda_bfloat16 This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ---------------------------------------------------------------------- ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126356 Approved by: https://github.com/drisspg --- test/test_nestedtensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 842033d005f09..597180129f727 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4246,7 +4246,7 @@ def in_proj(input_packed, qkv_linear=qkv_linear): # Low Precision Math Reference out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( q, k, v)[0].transpose(-2, -3) - output_ref_atol, output_ref_rtol = get_tolerances(out, out_lp_ref) + output_ref_atol, output_ref_rtol = get_tolerances(out, out_lp_ref, fudge_factor=2) self.assertEqual(out, out_component, atol=output_ref_atol, rtol=output_ref_rtol) From a83e745356fbb0adfd90d8300292a2df00765cef Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Mon, 20 May 2024 17:10:51 -0700 Subject: [PATCH 35/35] [BE] split seq_id to collective_seq_id and p2p_seq_id (#125727) Summary: Split out `seq_id` into `collective_seq_id` and `p2p_seq_id`. The main idea here is that collectives that go to all machines should have identical `collective_seq_id` and therefore it makes it easier to spot if one of machines isn't handling a collective operation. Next, we can attempt to match up p2p operations to ensure that the sender(s)/receivers(s) are in sync. Resolves issue: https://github.com/pytorch/pytorch/issues/125173 Test Plan: Unit tests. 
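As a rough sketch (not part of this change) of how the split IDs can be used: compare the latest `collective_seq_id` across per-rank flight recorder dumps to spot a straggler rank. The dump file names and world size below are hypothetical.

```python
import pickle

def last_collective_seq(dump_path):
    # Each dump is the pickled flight recorder trace for one rank.
    with open(dump_path, "rb") as f:
        trace = pickle.load(f)
    # Only true collectives share a collective_seq_id across all ranks;
    # skip p2p entries, which now track their own p2p_seq_id.
    collectives = [e for e in trace["entries"] if not e["is_p2p"]]
    return collectives[-1]["collective_seq_id"] if collectives else 0

seqs = {rank: last_collective_seq(f"rank{rank}_trace.pkl") for rank in range(4)}
latest = max(seqs.values())
print("ranks behind on collectives:", [r for r, s in seqs.items() if s < latest])
```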
Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/125727 Approved by: https://github.com/zdevito --- test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 4 +- test/distributed/test_c10d_nccl.py | 32 ++++++----- .../distributed/c10d/ProcessGroupNCCL.cpp | 55 ++++++++++++------- .../distributed/c10d/ProcessGroupNCCL.hpp | 11 ++-- torch/csrc/distributed/c10d/TraceUtils.h | 45 +++++++++++---- 5 files changed, 94 insertions(+), 53 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index aef97daae2e47..d83a9494112c6 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -69,7 +69,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const std::vector& outputs = {}, bool record = false) override { return c10::make_intrusive( - device, simulateError_, rank, opType, seq_); + device, simulateError_, rank, opType, seqCollective_); } size_t getNCCLCommCacheSize() { @@ -131,7 +131,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const std::vector& outputs = {}, bool record = false) override { return c10::make_intrusive( - device, setTimedoutError_, rank, opType, seq_); + device, setTimedoutError_, rank, opType, seqCollective_); } void setTimedoutError() { diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5a958acdbdd74..e71bfb52b2254 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -3524,7 +3524,7 @@ def test_short(self, timing_enabled): t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) ver = t["version"] - self.assertEqual(ver, "1.5") + self.assertEqual(ver, "2.0") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -3548,7 +3548,7 @@ def test_short(self, timing_enabled): self.assertIn("test_c10d_nccl.py", str(last["frames"])) self.assertEqual(last["input_sizes"], ((3, 4),)) self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["seq_id"], 2) + self.assertEqual(last["collective_seq_id"], 2) now = datetime.now() event_created_time = datetime.fromtimestamp( last["time_created_ns"] / 1000000000 @@ -3629,7 +3629,7 @@ def test_long(self): self.assertIn("test_c10d_nccl.py", str(last["frames"])) self.assertEqual(last["input_sizes"], ((3, 4),)) self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["seq_id"] - first["seq_id"], 9) + self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -3659,10 +3659,10 @@ def test_trace_while_active(self, timing_enabled): t = t["entries"] self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") if self.rank == 0: - self.assertEqual(t[-1]["seq_id"], 1) + self.assertEqual(t[-1]["collective_seq_id"], 1) self.assertEqual(t[-1]["state"], "completed") else: - self.assertEqual(t[-1]["seq_id"], 2) + self.assertEqual(t[-1]["collective_seq_id"], 2) self.assertEqual( t[-1]["state"], self.started_or_scheduled(timing_enabled) ) @@ -3704,10 +3704,10 @@ def gather_trace(): t = t["entries"] self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") if self.rank == 0: - self.assertEqual(t[-1]["seq_id"], 1) + self.assertEqual(t[-1]["collective_seq_id"], 1) self.assertEqual(t[-1]["state"], "completed") else: - self.assertEqual(t[-1]["seq_id"], 2) + 
self.assertEqual(t[-1]["collective_seq_id"], 2) self.assertEqual( t[-1]["state"], self.started_or_scheduled(timing_enabled) ) @@ -3799,7 +3799,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertEqual( t["entries"][p2p_op_idx]["profiling_name"], profiling_name ) - self.assertEqual(t["entries"][p2p_op_idx]["seq_id"], expected_seq) + self.assertEqual( + t["entries"][p2p_op_idx]["collective_seq_id"], expected_seq + ) self.assertEqual(t["entries"][p2p_op_idx]["op_id"], expected_op_id) expected_op_id += 1 self.assertEqual(t["entries"][p2p_op_idx]["input_sizes"], [input_sizes]) @@ -3819,7 +3821,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertEqual( t["entries"][coalesced_op]["profiling_name"], "nccl:coalesced" ) - self.assertEqual(t["entries"][coalesced_op]["seq_id"], expected_seq) + self.assertEqual( + t["entries"][coalesced_op]["collective_seq_id"], expected_seq + ) expected_seq += 1 self.assertEqual(t["entries"][coalesced_op]["state"], "completed") self.assertEqual(t["entries"][coalesced_op]["input_sizes"], []) @@ -3875,7 +3879,7 @@ def test_individual_send_recv(self, op_sizes, timing_enabled): input_sizes = op_sizes[seq % ops_per_repeat] profiling_name = "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0" self.assertEqual(t["entries"][seq]["profiling_name"], profiling_name) - self.assertEqual(t["entries"][seq]["seq_id"], expected_seq) + self.assertEqual(t["entries"][seq]["p2p_seq_id"], expected_seq) expected_seq += 1 self.assertEqual(t["entries"][seq]["op_id"], expected_op_id) expected_op_id += 1 @@ -3935,7 +3939,7 @@ def test_coalescing_manager_collective(self, timing_enabled): self.assertEqual( t["entries"][0]["profiling_name"], "nccl:reduce_scatter_tensor_coalesced" ) - self.assertEqual(t["entries"][0]["seq_id"], 1) + self.assertEqual(t["entries"][0]["collective_seq_id"], 1) self.assertEqual(t["entries"][0]["input_sizes"], [[2, 2], [2, 2]]) self.assertEqual( t["entries"][0]["output_sizes"], @@ -4003,9 +4007,9 @@ def test_timeout_dumps(self, timing_enabled): t = pickle.load(f) t = t["entries"] self.assertEqual(len(t), 2) - self.assertEqual(t[0]["seq_id"], 1) + self.assertEqual(t[0]["collective_seq_id"], 1) self.assertEqual(t[0]["state"], "completed") - self.assertEqual(t[1]["seq_id"], 2) + self.assertEqual(t[1]["collective_seq_id"], 2) self.assertEqual( t[1]["state"], self.started_or_scheduled(timing_enabled) ) @@ -4066,7 +4070,7 @@ def test_timeout_dumps_on_stuck_ranks(self): t = pickle.load(f) t = t["entries"] self.assertEqual(len(t), 1) - self.assertEqual(t[0]["seq_id"], 1) + self.assertEqual(t[0]["collective_seq_id"], 1) self.assertEqual(t[0]["state"], "completed") return diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 7586058475ff1..2319db06db643 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -931,7 +931,7 @@ void ProcessGroupNCCL::setSequenceNumberForGroup() { } // NCCL just starts sequence numbers at 0. uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { - return seq_; + return seqCollective_; } void ProcessGroupNCCL::registerOnCompletionHook( @@ -2246,7 +2246,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( device, rank, opType, - seq_, + seqCollective_, profilingTitle, profilingTitle != nullptr ? 
std::optional>(inputs) : c10::nullopt, @@ -2254,6 +2254,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( enableTiming_.load(), dist_debug_level_); if (record) { + bool isP2P = isP2POp(opType); // Ideally record every work that we enqueue, rather than every work we // create. // - at the time of this PR we do not currently enqueue every created work @@ -2270,13 +2271,15 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( r->trace_id_ = NCCLTraceBuffer::get()->record( uid_, std::make_tuple(pg_name_, pg_desc_), - seq_, + seqCollective_, + seqP2P_, op_id_, profilingTitle ? profilingTitle : "", inputs, outputs, r->ncclStartEvent_.get(), - r->ncclEndEvent_.get()); + r->ncclEndEvent_.get(), + isP2P); } return r; } @@ -2328,10 +2331,6 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; void ProcessGroupNCCL::startCoalescing() { - coalescedDevice_.set_index(-1); - coalescedComm_ = nullptr; - coalescing_state_ |= CoalActive; - groupStart(); // Other collective ops bump seq_ before creating a work. Thus, if coalesced // ops bump seq_ only after initing a work they will collide with (reuse) the // seq_ of the last non-coalesced collective. Previously, seq_ was bumped @@ -2340,10 +2339,19 @@ void ProcessGroupNCCL::startCoalescing() { // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during // start, which has one minor downside- we burn a seq_ if someone ever does a // 'start' and 'end' coalescing region without doing an operation inbetween. - seq_++; - // Don't bump op_id_ here, becuase startCoalescing isn't a logical operation. + // Don't bump op_id_ here, because startCoalescing isn't a logical operation. // Bump it for each logical op inside the coalescing group. + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; + } + + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); } // `optype` is for specifying a composite optype, such as ALLGATHER and @@ -2441,7 +2449,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( errorIfCapturingNonCapturableNCCL(capture_status); // Bump collective counter - seq_++; + seqCollective_++; op_id_++; auto device = getDevice(input); @@ -2596,9 +2604,10 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( errorIfCapturingNonCapturableNCCL(capture_status); // Bump collective counter - seq_++; + seqCollective_++; + // For coalescingManager collectives, there is no individual c++ call per - // collective so there is no flight record and we increment seq_ and op_id_ + // collective so there is no flight record and we increment seq*_ and op_id_ // together. Compare this to startCoalesing/endCoalescing flow where we // increment seq_ once per group and increment op_id_ once per indvidual // operation within the group @@ -2826,9 +2835,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; if (!coalescing_state_) { - // Bump sequence number. Don't do so if it's a batch P2P, it will be - // bumped in `endCoalescing`. - seq_++; + // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be + // bumped in `startCoalescing`. 
+ seqP2P_++; } } @@ -2869,13 +2878,15 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( auto trace_id = NCCLTraceBuffer::get()->record( uid_, std::make_tuple(pg_name_, pg_desc_), - seq_, + seqCollective_, + seqP2P_, op_id_, profilingTitle, {tensor}, {tensor}, nullptr, - nullptr); + nullptr, + /*isP2P=*/true); // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get // their timings/states updated by proxy when the Work obj representing the // coalesce group gets its update, we could accumulate these trace_ids @@ -2894,19 +2905,21 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // output, not sure what work->outputs_ = std::make_shared>(); work->outputs_->push_back(tensor); - // TODO(whc) becuase we don't pass output {tensor} to initWork, we tell + // TODO(whc) because we don't pass output {tensor} to initWork, we tell // initWork to not record, and then we manually call record passing all the // information it wants. work->trace_id_ = NCCLTraceBuffer::get()->record( uid_, std::make_tuple(pg_name_, pg_desc_), - seq_, + seqCollective_, + seqP2P_, op_id_, profilingTitle, {tensor}, {tensor}, work->ncclStartEvent_.get(), - work->ncclEndEvent_.get()); + work->ncclEndEvent_.get(), + /*isP2P=*/true); } // is gpuGuard needed for the if block below, or can i swap them diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 07f3730b1338b..995ae003a1cf0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1055,13 +1055,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Counting for the sequential number of NCCL collective call. // (specifically, how many actual kernels we launched, which differs from // op_id_ when coalescing is enabled) - uint64_t seq_{0}; + uint64_t seqCollective_{0}; + + // Counting for the sequential number of NCCL P2P calls. 
+ uint64_t seqP2P_{0}; // Incrementing counter for logical operations (collective or p2p) issued on // the ProcessGroup uint64_t op_id_{0}; - // the sequential number of the last colletive enqueued into workMetaList_ + // the sequential number of the last collective enqueued into workMetaList_ // This is useful for indentifying a rank that has not join a collective // initialized to be -1 to indicate no collective has been enqueued int64_t lastEnqueuedSeq_{-1}; @@ -1069,10 +1072,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the name of the last collective enqueued into workMetaList_ std::string lastEnqueuedWorkName_; - // the sequential number of the last colletive started as the kernal + // the sequential number of the last collective started as the kernel int64_t lastStartedSeq_{-1}; - // the name of the last collective started as the kernal + // the name of the last collective started as the kernel std::string lastStartedWorkName_; // the sequential number of the last colletive completed marked by diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 181f2208160b7..5b2fcc45c8f3f 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -2,17 +2,23 @@ #include #include +#include #include #include +#include + +#include #include #include #include #include +#include #include #include #include + namespace c10d { static c10::IValue entries_key = "entries"; @@ -20,12 +26,14 @@ static c10::IValue nccl_comm_key = "nccl_comm_state"; static c10::IValue version_key = "version"; // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "1.5"; +static c10::IValue version_val = "2.0"; static c10::IValue pg_config_key = "pg_config"; static c10::IValue record_id_key = "record_id"; static c10::IValue pg_id_key = "pg_id"; static c10::IValue pg_name_key = "process_group"; -static c10::IValue seq_id_key = "seq_id"; +static c10::IValue collective_seq_id_key = "collective_seq_id"; +static c10::IValue p2p_seq_id_key = "p2p_seq_id"; +static c10::IValue is_p2p_key = "is_p2p"; static c10::IValue op_id_key = "op_id"; static c10::IValue profiling_name_key = "profiling_name"; static c10::IValue input_sizes_key = "input_sizes"; @@ -428,11 +436,14 @@ struct NCCLTraceBuffer { size_t pg_id_; std::tuple pg_name_; // - // Both seq_id_ and op_id_ are per_pg incrementing counters - // seq_id refers to actual kernel launches (e.g. 1 per coalesced group) - // op_id refers to logical operations (e.g. one per op inside coalesced - // group) - size_t seq_id_; + // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1 + // per coalesced group). + // collective_seq_id only increments for true collective operations (over + // all ranks in the group). p2p_seq_id only increments over non-collective + // operations in the group. op_id refers to logical operations (e.g. one per + // op inside coalesced group) + size_t collective_seq_id_; + size_t p2p_seq_id_; size_t op_id_; std::string profiling_name_; @@ -445,6 +456,10 @@ struct NCCLTraceBuffer { // timestamp when the entry was created, likely close to the time the work // was 'enqueued'- not necessarily started c10::time_t time_created_; + + // Is this a P2P event? + bool isP2P_; + std::optional duration_; // timestamp when our CPU threads discovered that the kernel started. 
@@ -479,13 +494,15 @@ struct NCCLTraceBuffer { std::optional record( size_t pg_id, const std::tuple& pg_name, - size_t seq_id, + size_t collective_seq_id, + size_t p2p_seq_id, size_t op_id, std::string profiling_name, const std::vector& inputs, const std::vector& outputs, Event* start, - Event* end) { + Event* end, + bool isP2P) { if (!enabled_) { return c10::nullopt; } @@ -497,13 +514,15 @@ struct NCCLTraceBuffer { id_, pg_id, pg_name, - seq_id, + collective_seq_id, + p2p_seq_id, op_id, std::move(profiling_name), std::move(traceback), std::move(start), std::move(end), - c10::getTime()}; + c10::getTime(), + isP2P}; for (const auto& input : inputs) { c10::IntArrayRef sizes = input.sizes(); @@ -656,7 +675,8 @@ struct NCCLTraceBuffer { dict.insert(record_id_key, int64_t(e.id_)); dict.insert(pg_id_key, int64_t(e.pg_id_)); dict.insert(pg_name_key, e.pg_name_); - dict.insert(seq_id_key, int64_t(e.seq_id_)); + dict.insert(collective_seq_id_key, int64_t(e.collective_seq_id_)); + dict.insert(p2p_seq_id_key, int64_t(e.p2p_seq_id_)); dict.insert(op_id_key, int64_t(e.op_id_)); dict.insert(profiling_name_key, e.profiling_name_); dict.insert(time_created_key, int64_t(e.time_created_)); @@ -699,6 +719,7 @@ struct NCCLTraceBuffer { ? int64_t(*e.time_discovered_completed_) : c10::IValue()); dict.insert(retired_key, e.retired_); + dict.insert(is_p2p_key, e.isP2P_); auto frames = new_list(); for (int64_t frame : tb) {