From dc5e36dd7bfde15842fa1e939dd47440e6b5eb5c Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Mon, 2 Mar 2020 23:27:37 -0800 Subject: [PATCH] Enable network sharing in torch_glow (#4235) Summary: Pull Request resolved: https://github.com/pytorch/glow/pull/4235 We will compile and cache the glow fused graph only once, shared across all the module/function from different inference threads. We have a global singleton of CachingGraphRunner keyed on the hash of jit::Graph. For each unique jit::Graph, we have one CachingGraphRunner. Then for each shape specialization of CachingGraphRunner, we have a unique Glow function. Strategy on caching CachingGraphRunner: We will first query by hashing the fused subgraph. If not found, we will query by hashing the node symbol string, which suits the AOT case. Currently, we have to leak the CachingGraphRunner. This is fine for the foreseeable future but we probably need more application level (say, predictor) signal to help with this. For example, if the application framework can save an instance of the CachingGraphRunner as a shared_ptr in AOT case and free it when the model is unloaded, we can then use weak_ptr for the graph runner map and hence will make sure everything will release except for a null weak_ptr. 
Reviewed By: jackm321 Differential Revision: D20174479 fbshipit-source-id: aa0a564700ded2a02c988b2cad535d5a10c6cfb8 --- .circleci/test.sh | 1 + torch_glow/setup.cfg | 2 +- torch_glow/src/CachingGraphRunner.cpp | 45 +++++++++----- torch_glow/src/CachingGraphRunner.h | 10 ++- torch_glow/src/Registration.cpp | 88 ++++++++++++++++++++++----- torch_glow/src/Registration.h | 20 +++--- torch_glow/tests/utils.py | 6 ++ 7 files changed, 129 insertions(+), 43 deletions(-) diff --git a/.circleci/test.sh b/.circleci/test.sh index 6c6361ec8d..740912e763 100755 --- a/.circleci/test.sh +++ b/.circleci/test.sh @@ -52,6 +52,7 @@ run_pytorch_tests() { export PATH="/tmp/sccache:$PATH" fi source /tmp/venv/bin/activate + pip install pytest-xdist python "${GLOW_SRC}/torch_glow/setup.py" test --run_cmake cd - if hash sccache 2>/dev/null; then diff --git a/torch_glow/setup.cfg b/torch_glow/setup.cfg index 8b109b9904..3a638e5c28 100644 --- a/torch_glow/setup.cfg +++ b/torch_glow/setup.cfg @@ -3,4 +3,4 @@ test=pytest [tool:pytest] testpaths = tests -addopts = --verbose +addopts = --verbose --forked diff --git a/torch_glow/src/CachingGraphRunner.cpp b/torch_glow/src/CachingGraphRunner.cpp index 0f3db73a0b..c368500721 100644 --- a/torch_glow/src/CachingGraphRunner.cpp +++ b/torch_glow/src/CachingGraphRunner.cpp @@ -82,11 +82,10 @@ void CachingGraphRunner::aggregateAndDumpTraces(TraceContext *traceContext, } } -Expected +Expected> CachingGraphRunner::loadImpl(torch::jit::Stack &stack, TraceContext *traceContext) { TRACE_EVENT_SCOPE(traceContext, TraceLevel::RUNTIME, "torch_glow::loadImpl"); - const auto inputs = torch::jit::last(stack, graph_->inputs().size()); TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME, "computeGraphHash"); @@ -95,11 +94,19 @@ CachingGraphRunner::loadImpl(torch::jit::Stack &stack, // If we already have a Glow function compiled for this graph with and the // given inputs then use that. 
+ { + std::shared_lock rlock(graphInfoMapMutex); + auto it = perGlowGraphInfoMap_.find(hash); + if (it != perGlowGraphInfoMap_.end()) { + return it->second; + } + } + + std::unique_lock wlock(graphInfoMapMutex); auto it = perGlowGraphInfoMap_.find(hash); if (it != perGlowGraphInfoMap_.end()) { - return it->second.get(); + return it->second; } - auto info = std::make_shared(); info->functionName = strFormat("pt_function_%lu", hash); @@ -130,9 +137,10 @@ CachingGraphRunner::loadImpl(torch::jit::Stack &stack, RETURN_IF_ERR(hostManager_->addNetwork(std::move(module), cctx)); TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME, "addNetwork"); - perGlowGraphInfoMap_[hash] = std::move(info); + auto ret = perGlowGraphInfoMap_.emplace(hash, info); + CHECK(ret.second); - return perGlowGraphInfoMap_[hash].get(); + return ret.first->second; } Error CachingGraphRunner::runImpl( @@ -234,7 +242,6 @@ Error CachingGraphRunner::runImpl( } Error CachingGraphRunner::run(torch::jit::Stack &stack) { - PerGlowGraphInfo *info; std::unique_ptr ctx = glow::make_unique(); TraceContext *traceContext = nullptr; @@ -246,9 +253,9 @@ Error CachingGraphRunner::run(torch::jit::Stack &stack) { TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME, "torch_glow::run"); + std::shared_ptr info; ASSIGN_VALUE_OR_RETURN_ERR(info, loadImpl(stack, traceContext)); - - auto err = runImpl(*DCHECK_NOTNULL(info), stack, ctx); + auto err = runImpl(*DCHECK_NOTNULL(info.get()), stack, ctx); // Reset the traceContext again in case it was changed during run. 
traceContext = ctx->getTraceContext(); @@ -261,15 +268,19 @@ Error CachingGraphRunner::run(torch::jit::Stack &stack) { } Error CachingGraphRunner::runOnly(torch::jit::Stack &stack) { - if (perGlowGraphInfoMap_.size() != 1) { - return MAKE_ERR(strFormat( - "There should be one and only one compiled graph, but got %lu", - perGlowGraphInfoMap_.size())); + std::shared_ptr info; + { + std::shared_lock rlock(graphInfoMapMutex); + if (perGlowGraphInfoMap_.size() != 1) { + return MAKE_ERR(strFormat( + "There should be one and only one compiled graph, but got %lu", + perGlowGraphInfoMap_.size())); + } + info = perGlowGraphInfoMap_.at(0); } std::unique_ptr ctx = glow::make_unique(); - - return runImpl(*DCHECK_NOTNULL(perGlowGraphInfoMap_[0].get()), stack, ctx); + return runImpl(*DCHECK_NOTNULL(info.get()), stack, ctx); } Error CachingGraphRunner::warmCache(const std::vector &inputMeta) { @@ -281,6 +292,7 @@ Error CachingGraphRunner::warmCache(const std::vector &inputMeta) { return MAKE_ERR("No graph found!"); } + std::unique_lock wlock(graphInfoMapMutex); if (perGlowGraphInfoMap_.size() != 0) { return MAKE_ERR(strFormat("There is already a compiled graph!")); } @@ -300,7 +312,7 @@ Error CachingGraphRunner::warmCache(const std::vector &inputMeta) { RETURN_IF_ERR(hostManager_->addNetwork(std::move(glowModule), cctx)); // Randomly picked one key. There should be only one element in the map // when model is precompiled. - perGlowGraphInfoMap_[0] = std::move(info); + perGlowGraphInfoMap_[0] = info; return Error::success(); } @@ -318,6 +330,7 @@ CachingGraphRunner::~CachingGraphRunner() { aggregateAndDumpTraces(nullptr, /* flush */ true); // Remove Glow functions saved in HostManager when being destroyed. 
+ std::unique_lock wlock(graphInfoMapMutex); for (auto &kv : perGlowGraphInfoMap_) { ERR_TO_BOOL(hostManager_->removeNetwork(kv.second->functionName)); } diff --git a/torch_glow/src/CachingGraphRunner.h b/torch_glow/src/CachingGraphRunner.h index fffbba9969..0ba99b10e2 100644 --- a/torch_glow/src/CachingGraphRunner.h +++ b/torch_glow/src/CachingGraphRunner.h @@ -21,9 +21,10 @@ #include "glow/Runtime/HostManager/HostManager.h" #include - #include +#include + namespace glow { /// For a given PyTorch JIT graph, this class is responsible for maintaining a @@ -70,14 +71,17 @@ class CachingGraphRunner { /// in groups to file. std::unique_ptr mergedTraceContext_; + /// Lock for concurrent accessing to perGlowGraphInfoMap_. + std::shared_timed_mutex graphInfoMapMutex; + /// Given a PyTorch input stack \p stack, this generates a hash from the /// values on the stack and checks to see if a matching function was loaded /// previously. If a matching function was loaded previously then its cached /// info is returned immediately. Otherwise this loads the /// subgraph into the owned HostManager, creates a PerGlowGraphInfo which is /// cached for the given inputs, and then \returns this PerGlowGraphInfo. - Expected loadImpl(torch::jit::Stack &stack, - TraceContext *traceContext); + Expected> + loadImpl(torch::jit::Stack &stack, TraceContext *traceContext); /// Given a PerGlowGraphInfo \p info for a subgraph that was previously /// loaded, this runs the Glow function that corresponds to that diff --git a/torch_glow/src/Registration.cpp b/torch_glow/src/Registration.cpp index 011382b5e4..150c19f04a 100644 --- a/torch_glow/src/Registration.cpp +++ b/torch_glow/src/Registration.cpp @@ -21,39 +21,83 @@ #include #include +#include #include #include #include #include #include #include +#include + +#include +#include namespace glow { -std::unordered_map> & + +namespace { +/// Lock to protect the global graph runner map. 
+std::shared_timed_mutex runnerMapMutex; + +size_t hashBlock(torch::jit::Block *block, size_t seed) { + size_t s = seed; + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { + auto *node = *it; + s = torch::hash_combine(s, torch::jit::HashNode{}(node)); + if (!node->blocks().empty()) { + for (auto b : node->blocks()) { + s = hashBlock(b, s); + } + } + } + return s; +} + +size_t hashGraph(std::shared_ptr g) { + return hashBlock(g->block(), 0xcaffe2); +} + +std::unordered_map> & getPreloadedRunnerMap() { - static std::unordered_map> + static std::unordered_map> preloadedGraphRunners_; return preloadedGraphRunners_; } +} // namespace -std::unique_ptr +size_t getGraphRunnerMapSize() { + const auto &m = getPreloadedRunnerMap(); + std::shared_lock rlock(runnerMapMutex); + return m.size(); +} + +std::shared_ptr getGraphRunnerForKey(const std::string &key) { auto &preloadedRunners = getPreloadedRunnerMap(); + std::shared_lock rlock(runnerMapMutex); auto it = preloadedRunners.find(key); if (it == preloadedRunners.end()) { return nullptr; } else { - auto graphRunner = std::move(it->second); - preloadedRunners.erase(it); - return graphRunner; + return it->second; } } -void setGraphRunnerForKey(const std::string &key, - std::unique_ptr graphRunner) { +std::shared_ptr +setGraphRunnerForKey(const std::string &key, + std::function(void)> + graphRunnerBuilder) { auto &preloadedRunners = getPreloadedRunnerMap(); - DCHECK_EQ(preloadedRunners.count(key), 0); - preloadedRunners[key] = std::move(graphRunner); + std::unique_lock wlock(runnerMapMutex); + const auto it = preloadedRunners.find(key); + if (it != preloadedRunners.end()) { + return it->second; + } + + auto runner = graphRunnerBuilder(); + auto ret = preloadedRunners.emplace(key, runner); + CHECK(ret.second); + return runner; } void registerGlowOp(const c10::Symbol &symbol) { @@ -63,17 +107,29 @@ void registerGlowOp(const c10::Symbol &symbol) { torch::jit::RegisterOperators op({torch::jit::Operator( 
symbol, [](const torch::jit::Node *node) -> torch::jit::Operation { - std::string key = node->kind().toQualString(); - + // How to find a graphRunner: + // 1. See if a key based on graph hash has been registered + // 2. If not, see if a key based on fusion node symbol string has been + // registered, which is usually done in AOT fashion + // 3. If not, create a graphRunner with graph hash as a key + size_t key = hashGraph(node->g(at::attr::Subgraph)); + std::string keyStr(sizeof(size_t), '\0'); + std::memcpy(&keyStr[0], &key, sizeof(key)); std::shared_ptr graphRunner = - getGraphRunnerForKey(key); + getGraphRunnerForKey(keyStr); + + if (!graphRunner) { + graphRunner = getGraphRunnerForKey(node->kind().toQualString()); + } // If no preloaded graph runner was created for this node, create a new // empty one. if (!graphRunner) { - graphRunner = std::make_shared( - node->g(at::attr::Subgraph), getHostManager(), - getPyTorchLoaderSettings()); + graphRunner = setGraphRunnerForKey(keyStr, [node]() { + return std::make_shared( + node->g(at::attr::Subgraph), getHostManager(), + getPyTorchLoaderSettings()); + }); } return [graphRunner](torch::jit::Stack &stack) { diff --git a/torch_glow/src/Registration.h b/torch_glow/src/Registration.h index e5fb90af86..3ce5fca240 100644 --- a/torch_glow/src/Registration.h +++ b/torch_glow/src/Registration.h @@ -35,16 +35,22 @@ void registerGlowFusionPass(std::function enablePassFn); /// registered. void registerGlowFusionOpAndPass(std::function enablePassFn); -/// Store a CachingGraphRunner \p graphRunner under a given \p key for later -/// later use. This is so that a CachingGraphRunner can be preloaded for a given -/// graph and then stored until the corresponding pt node is created for that -/// graph. -void setGraphRunnerForKey(const std::string &key, - std::unique_ptr graphRunner); +/// Get the size of the global CachingGraphRunner map. For testing and debugging +/// purpose only. 
+size_t getGraphRunnerMapSize(); + +/// Store a CachingGraphRunner with a constucting functor \p graphRunnerBuilder +/// under a given \p key for later later use. This is so that a +/// CachingGraphRunner can be preloaded for a given graph and then stored until +/// the corresponding pt node is created for that graph. +std::shared_ptr +setGraphRunnerForKey(const std::string &key, + std::function(void)> + graphRunnerBuilder); /// Get a precreated CachingGraphRunner for a given \p key. \returns nullptr if /// no CachingGraphRunner was registered for the given key. -std::unique_ptr +std::shared_ptr getGraphRunnerForKey(const std::string &key); } // namespace glow diff --git a/torch_glow/tests/utils.py b/torch_glow/tests/utils.py index 643b7649e0..8feebaeb3e 100644 --- a/torch_glow/tests/utils.py +++ b/torch_glow/tests/utils.py @@ -3,6 +3,8 @@ import torch import torch_glow +import sys + GLOW_NODE_NAME = "glow::FusionGroup" SUBGRAPH_ATTR = "Subgraph" @@ -84,9 +86,13 @@ def jitVsGlow_(f_torch, f_glow, check_trace, atol, rtol, *inputs, expected_fused assert isinstance(torch_res, tuple) and isinstance(glow_res, tuple) assert len(torch_res) == len(glow_res) for i in range(len(torch_res)): + print("torch shape: {}".format(torch_res[i].shape), file=sys.stderr) + print("glow shape: {}".format(glow_res[i].shape), file=sys.stderr) assert torch.allclose( torch_res[i], glow_res[i], atol=atol, rtol=rtol) else: + print("torch shape: {}".format(torch_res.shape), file=sys.stderr) + print("glow shape: {}".format(glow_res.shape), file=sys.stderr) is_all_close = torch.allclose( torch_res, glow_res, atol=atol, rtol=rtol) if not is_all_close: