1 change: 1 addition & 0 deletions .circleci/test.sh
@@ -52,6 +52,7 @@ run_pytorch_tests() {
export PATH="/tmp/sccache:$PATH"
fi
source /tmp/venv/bin/activate
pip install pytest-xdist
python "${GLOW_SRC}/torch_glow/setup.py" test --run_cmake
cd -
if hash sccache 2>/dev/null; then
2 changes: 1 addition & 1 deletion torch_glow/setup.cfg
@@ -3,4 +3,4 @@ test=pytest

[tool:pytest]
testpaths = tests
addopts = --verbose
addopts = --verbose --forked
45 changes: 29 additions & 16 deletions torch_glow/src/CachingGraphRunner.cpp
@@ -82,11 +82,10 @@ void CachingGraphRunner::aggregateAndDumpTraces(TraceContext *traceContext,
}
}

Expected<CachingGraphRunner::PerGlowGraphInfo *>
Expected<std::shared_ptr<CachingGraphRunner::PerGlowGraphInfo>>
CachingGraphRunner::loadImpl(torch::jit::Stack &stack,
TraceContext *traceContext) {
TRACE_EVENT_SCOPE(traceContext, TraceLevel::RUNTIME, "torch_glow::loadImpl");

const auto inputs = torch::jit::last(stack, graph_->inputs().size());

TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME, "computeGraphHash");
@@ -95,11 +94,19 @@ CachingGraphRunner::loadImpl(torch::jit::Stack &stack,

// If we already have a Glow function compiled for this graph with the
// given inputs then use that.
{
std::shared_lock<std::shared_timed_mutex> rlock(graphInfoMapMutex);
auto it = perGlowGraphInfoMap_.find(hash);
if (it != perGlowGraphInfoMap_.end()) {
return it->second;
}
}

std::unique_lock<std::shared_timed_mutex> wlock(graphInfoMapMutex);
auto it = perGlowGraphInfoMap_.find(hash);
if (it != perGlowGraphInfoMap_.end()) {
return it->second.get();
return it->second;
}

auto info = std::make_shared<PerGlowGraphInfo>();
info->functionName = strFormat("pt_function_%lu", hash);

@@ -130,9 +137,10 @@ CachingGraphRunner::loadImpl(torch::jit::Stack &stack,
RETURN_IF_ERR(hostManager_->addNetwork(std::move(module), cctx));
TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME, "addNetwork");

perGlowGraphInfoMap_[hash] = std::move(info);
auto ret = perGlowGraphInfoMap_.emplace(hash, info);
CHECK(ret.second);

return perGlowGraphInfoMap_[hash].get();
return ret.first->second;
}

Error CachingGraphRunner::runImpl(
@@ -234,7 +242,6 @@ Error CachingGraphRunner::runImpl(
}

Error CachingGraphRunner::run(torch::jit::Stack &stack) {
PerGlowGraphInfo *info;
std::unique_ptr<ExecutionContext> ctx = glow::make_unique<ExecutionContext>();

TraceContext *traceContext = nullptr;
@@ -246,9 +253,9 @@ Error CachingGraphRunner::run(torch::jit::Stack &stack) {

TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME, "torch_glow::run");

std::shared_ptr<PerGlowGraphInfo> info;
ASSIGN_VALUE_OR_RETURN_ERR(info, loadImpl(stack, traceContext));

auto err = runImpl(*DCHECK_NOTNULL(info), stack, ctx);
auto err = runImpl(*DCHECK_NOTNULL(info.get()), stack, ctx);

// Reset the traceContext again in case it was changed during run.
traceContext = ctx->getTraceContext();
@@ -261,15 +268,19 @@ Error CachingGraphRunner::run(torch::jit::Stack &stack) {
}

Error CachingGraphRunner::runOnly(torch::jit::Stack &stack) {
if (perGlowGraphInfoMap_.size() != 1) {
return MAKE_ERR(strFormat(
"There should be one and only one compiled graph, but got %lu",
perGlowGraphInfoMap_.size()));
std::shared_ptr<PerGlowGraphInfo> info;
{
std::shared_lock<std::shared_timed_mutex> rlock(graphInfoMapMutex);
if (perGlowGraphInfoMap_.size() != 1) {
return MAKE_ERR(strFormat(
"There should be one and only one compiled graph, but got %lu",
perGlowGraphInfoMap_.size()));
}
info = perGlowGraphInfoMap_.at(0);
}

std::unique_ptr<ExecutionContext> ctx = glow::make_unique<ExecutionContext>();

return runImpl(*DCHECK_NOTNULL(perGlowGraphInfoMap_[0].get()), stack, ctx);
return runImpl(*DCHECK_NOTNULL(info.get()), stack, ctx);
}

Error CachingGraphRunner::warmCache(const std::vector<InputMeta> &inputMeta) {
@@ -281,6 +292,7 @@ Error CachingGraphRunner::warmCache(const std::vector<InputMeta> &inputMeta) {
return MAKE_ERR("No graph found!");
}

std::unique_lock<std::shared_timed_mutex> wlock(graphInfoMapMutex);
if (perGlowGraphInfoMap_.size() != 0) {
return MAKE_ERR(strFormat("There is already a compiled graph!"));
}
@@ -300,7 +312,7 @@ Error CachingGraphRunner::warmCache(const std::vector<InputMeta> &inputMeta) {
RETURN_IF_ERR(hostManager_->addNetwork(std::move(glowModule), cctx));
// Randomly picked one key. There should be only one element in the map
// when the model is precompiled.
perGlowGraphInfoMap_[0] = std::move(info);
perGlowGraphInfoMap_[0] = info;
return Error::success();
}

@@ -318,6 +330,7 @@ CachingGraphRunner::~CachingGraphRunner() {
aggregateAndDumpTraces(nullptr, /* flush */ true);

// Remove Glow functions saved in HostManager when being destroyed.
std::unique_lock<std::shared_timed_mutex> wlock(graphInfoMapMutex);
for (auto &kv : perGlowGraphInfoMap_) {
ERR_TO_BOOL(hostManager_->removeNetwork(kv.second->functionName));
}
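The loadImpl() change above is a double-checked lookup on perGlowGraphInfoMap_: the common case takes only a shared (reader) lock on graphInfoMapMutex, while a cache miss upgrades to an exclusive lock and re-checks before compiling, so two threads racing on the same graph hash end up sharing one PerGlowGraphInfo. Below is a minimal standalone sketch of that pattern, assuming nothing beyond the standard library; GraphCache, GraphInfo, and getOrCreate are illustrative names rather than types from the PR, and the compile-and-addNetwork work is left as a comment.

#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct GraphInfo {
  std::string functionName;
};

class GraphCache {
  std::unordered_map<size_t, std::shared_ptr<GraphInfo>> map_;
  std::shared_timed_mutex mutex_;

public:
  std::shared_ptr<GraphInfo> getOrCreate(size_t hash) {
    {
      // Fast path: many readers may hold the shared lock concurrently.
      std::shared_lock<std::shared_timed_mutex> rlock(mutex_);
      auto it = map_.find(hash);
      if (it != map_.end()) {
        return it->second;
      }
    }
    // Slow path: take the exclusive lock and re-check, because another
    // thread may have inserted the entry between the two lock scopes.
    std::unique_lock<std::shared_timed_mutex> wlock(mutex_);
    auto it = map_.find(hash);
    if (it != map_.end()) {
      return it->second;
    }
    auto info = std::make_shared<GraphInfo>();
    info->functionName = "pt_function_" + std::to_string(hash);
    // ... build and add the Glow network here in the real code ...
    auto ret = map_.emplace(hash, info);
    return ret.first->second;  // shared_ptr keeps the entry alive for the caller
  }
};

In the diff itself the Glow module is built and hostManager_->addNetwork() is called while the exclusive lock is still held, which serializes first-time compilations across threads; that keeps the map consistent at the cost of some parallelism on cold starts.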
10 changes: 7 additions & 3 deletions torch_glow/src/CachingGraphRunner.h
@@ -21,9 +21,10 @@
#include "glow/Runtime/HostManager/HostManager.h"

#include <torch/csrc/jit/ir/ir.h>

#include <torch/csrc/jit/serialization/import.h>

#include <shared_mutex>

namespace glow {

/// For a given PyTorch JIT graph, this class is responsible for maintaining a
@@ -70,14 +71,17 @@ class CachingGraphRunner {
/// in groups to file.
std::unique_ptr<TraceContext> mergedTraceContext_;

/// Lock guarding concurrent access to perGlowGraphInfoMap_.
std::shared_timed_mutex graphInfoMapMutex;

/// Given a PyTorch input stack \p stack, this generates a hash from the
/// values on the stack and checks to see if a matching function was loaded
/// previously. If a matching function was loaded previously then its cached
/// info is returned immediately. Otherwise this loads the
/// subgraph into the owned HostManager, creates a PerGlowGraphInfo which is
/// cached for the given inputs, and then \returns this PerGlowGraphInfo.
Expected<PerGlowGraphInfo *> loadImpl(torch::jit::Stack &stack,
TraceContext *traceContext);
Expected<std::shared_ptr<PerGlowGraphInfo>>
loadImpl(torch::jit::Stack &stack, TraceContext *traceContext);

/// Given a PerGlowGraphInfo \p info for a subgraph that was previously
/// loaded, this runs the Glow function that corresponds to that
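The most consequential header change is loadImpl()'s return type: handing back std::shared_ptr<PerGlowGraphInfo> instead of a raw pointer lets the caller keep the info alive on its own, independent of whatever happens to the map entry after the lock is released. A tiny illustration of that ownership difference, with Info and the map below as made-up stand-ins rather than the PR's types:

#include <cassert>
#include <memory>
#include <unordered_map>

struct Info {
  int id;
};

int main() {
  std::unordered_map<int, std::shared_ptr<Info>> cache;
  cache.emplace(0, std::make_shared<Info>(Info{42}));

  // The caller copies the shared_ptr while it can still see the entry...
  std::shared_ptr<Info> held = cache.at(0);

  // ...and the entry is later removed by some other code path.
  cache.erase(0);

  // `held` still owns a reference, so the object is alive; a raw Info*
  // taken out of the map would be dangling at this point.
  assert(held->id == 42);
  return 0;
}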
88 changes: 72 additions & 16 deletions torch_glow/src/Registration.cpp
@@ -21,39 +21,83 @@

#include <glog/logging.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/utils/hash.h>

#include <mutex>
#include <shared_mutex>

namespace glow {
std::unordered_map<std::string, std::unique_ptr<CachingGraphRunner>> &

namespace {
/// Lock to protect the global graph runner map.
std::shared_timed_mutex runnerMapMutex;

size_t hashBlock(torch::jit::Block *block, size_t seed) {
size_t s = seed;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto *node = *it;
s = torch::hash_combine(s, torch::jit::HashNode{}(node));
if (!node->blocks().empty()) {
for (auto b : node->blocks()) {
s = hashBlock(b, s);
}
}
}
return s;
}

size_t hashGraph(std::shared_ptr<torch::jit::Graph> g) {
return hashBlock(g->block(), 0xcaffe2);
}

std::unordered_map<std::string, std::shared_ptr<CachingGraphRunner>> &
getPreloadedRunnerMap() {
static std::unordered_map<std::string, std::unique_ptr<CachingGraphRunner>>
static std::unordered_map<std::string, std::shared_ptr<CachingGraphRunner>>
preloadedGraphRunners_;
return preloadedGraphRunners_;
}
} // namespace

std::unique_ptr<CachingGraphRunner>
size_t getGraphRunnerMapSize() {
const auto &m = getPreloadedRunnerMap();
std::shared_lock<std::shared_timed_mutex> rlock(runnerMapMutex);
return m.size();
}

std::shared_ptr<CachingGraphRunner>
getGraphRunnerForKey(const std::string &key) {
auto &preloadedRunners = getPreloadedRunnerMap();
std::shared_lock<std::shared_timed_mutex> rlock(runnerMapMutex);
auto it = preloadedRunners.find(key);
if (it == preloadedRunners.end()) {
return nullptr;
} else {
auto graphRunner = std::move(it->second);
preloadedRunners.erase(it);
return graphRunner;
return it->second;
}
}

void setGraphRunnerForKey(const std::string &key,
std::unique_ptr<CachingGraphRunner> graphRunner) {
std::shared_ptr<CachingGraphRunner>
setGraphRunnerForKey(const std::string &key,
std::function<std::shared_ptr<CachingGraphRunner>(void)>
graphRunnerBuilder) {
auto &preloadedRunners = getPreloadedRunnerMap();
DCHECK_EQ(preloadedRunners.count(key), 0);
preloadedRunners[key] = std::move(graphRunner);
std::unique_lock<std::shared_timed_mutex> wlock(runnerMapMutex);
const auto it = preloadedRunners.find(key);
if (it != preloadedRunners.end()) {
return it->second;
}

auto runner = graphRunnerBuilder();
auto ret = preloadedRunners.emplace(key, runner);
CHECK(ret.second);
return runner;
}

void registerGlowOp(const c10::Symbol &symbol) {
@@ -63,17 +107,29 @@ void registerGlowOp(const c10::Symbol &symbol) {
torch::jit::RegisterOperators op({torch::jit::Operator(
symbol,
[](const torch::jit::Node *node) -> torch::jit::Operation {
std::string key = node->kind().toQualString();

// How to find a graphRunner:
// 1. See if a key based on graph hash has been registered
// 2. If not, see if a key based on fusion node symbol string has been
// registered, which is usually done in AOT fashion
// 3. If not, create a graphRunner with graph hash as a key
size_t key = hashGraph(node->g(at::attr::Subgraph));
std::string keyStr(sizeof(size_t), '\0');
std::memcpy(&keyStr[0], &key, sizeof(key));
std::shared_ptr<CachingGraphRunner> graphRunner =
getGraphRunnerForKey(key);
getGraphRunnerForKey(keyStr);

if (!graphRunner) {
graphRunner = getGraphRunnerForKey(node->kind().toQualString());
}

// If no preloaded graph runner was created for this node, create a new
// empty one.
if (!graphRunner) {
graphRunner = std::make_shared<CachingGraphRunner>(
node->g(at::attr::Subgraph), getHostManager(),
getPyTorchLoaderSettings());
graphRunner = setGraphRunnerForKey(keyStr, [node]() {
return std::make_shared<CachingGraphRunner>(
node->g(at::attr::Subgraph), getHostManager(),
getPyTorchLoaderSettings());
});
}

return [graphRunner](torch::jit::Stack &stack) {
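Two things change in the registration path above: the cache key becomes a structural hash of the fusion subgraph (hashGraph/hashBlock fold per-node torch::jit::HashNode values with torch::hash_combine) packed byte-for-byte into a std::string, and setGraphRunnerForKey becomes a get-or-create that takes a builder and re-checks before inserting. The sketch below condenses that lookup order into a standalone form; Runner, resolveRunner, and the unguarded registry map are simplifications (the real map is a function-local static guarded by runnerMapMutex), and the graph hash is taken as an input rather than recomputed.

#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct Runner {};  // stand-in for CachingGraphRunner

// Pack the size_t graph hash into a fixed-width binary string key,
// mirroring the memcpy in registerGlowOp above.
std::string packKey(size_t hash) {
  std::string key(sizeof(size_t), '\0');
  std::memcpy(&key[0], &hash, sizeof(hash));
  return key;
}

// Toy registry; the real one is guarded by runnerMapMutex.
std::unordered_map<std::string, std::shared_ptr<Runner>> registry;

std::shared_ptr<Runner> findRunner(const std::string &key) {
  auto it = registry.find(key);
  return it == registry.end() ? nullptr : it->second;
}

// Get-or-create: the builder only runs if the key is still absent,
// so two racing registrations converge on a single runner.
std::shared_ptr<Runner>
getOrCreate(const std::string &key,
            std::function<std::shared_ptr<Runner>()> build) {
  if (auto existing = findRunner(key)) {
    return existing;
  }
  auto runner = build();
  registry.emplace(key, runner);
  return runner;
}

// Lookup order when the fused op is instantiated:
//   1. key derived from the subgraph hash,
//   2. the fusion node's qualified symbol string (AOT-preloaded runners),
//   3. otherwise build a fresh runner under the hash key.
std::shared_ptr<Runner> resolveRunner(size_t graphHash,
                                      const std::string &symbolKey) {
  std::string hashKey = packKey(graphHash);
  if (auto r = findRunner(hashKey)) {
    return r;
  }
  if (auto r = findRunner(symbolKey)) {
    return r;
  }
  return getOrCreate(hashKey, [] { return std::make_shared<Runner>(); });
}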
20 changes: 13 additions & 7 deletions torch_glow/src/Registration.h
@@ -35,16 +35,22 @@ void registerGlowFusionPass(std::function<bool()> enablePassFn);
/// registered.
void registerGlowFusionOpAndPass(std::function<bool()> enablePassFn);

/// Store a CachingGraphRunner \p graphRunner under a given \p key for later
/// use. This is so that a CachingGraphRunner can be preloaded for a given
/// graph and then stored until the corresponding pt node is created for that
/// graph.
void setGraphRunnerForKey(const std::string &key,
std::unique_ptr<CachingGraphRunner> graphRunner);
/// Get the size of the global CachingGraphRunner map. For testing and
/// debugging purposes only.
size_t getGraphRunnerMapSize();

/// Store a CachingGraphRunner with a constructing functor \p graphRunnerBuilder
/// under a given \p key for later use. This is so that a
/// CachingGraphRunner can be preloaded for a given graph and then stored until
/// the corresponding pt node is created for that graph.
std::shared_ptr<CachingGraphRunner>
setGraphRunnerForKey(const std::string &key,
std::function<std::shared_ptr<CachingGraphRunner>(void)>
graphRunnerBuilder);

/// Get a precreated CachingGraphRunner for a given \p key. \returns nullptr if
/// no CachingGraphRunner was registered for the given key.
std::unique_ptr<CachingGraphRunner>
std::shared_ptr<CachingGraphRunner>
getGraphRunnerForKey(const std::string &key);
} // namespace glow

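For code that preloads runners ahead of time, the practical difference in this header is that setGraphRunnerForKey no longer consumes a ready-made unique_ptr; it takes a builder and returns whichever runner ends up in the map, so a caller that loses a registration race simply receives the winner's instance. A hedged usage sketch against these declarations follows; the caller function and its graph argument are hypothetical, the includes are assumed, and getHostManager()/getPyTorchLoaderSettings() are the accessors already used in Registration.cpp.

// Assumed includes: "Registration.h", "CachingGraphRunner.h", and the
// torch/glow headers they depend on.
#include <memory>
#include <string>

// Hypothetical ahead-of-time preload path for one fusion symbol.
std::shared_ptr<glow::CachingGraphRunner>
preloadRunnerForSymbol(const std::string &fusionSymbol,
                       std::shared_ptr<torch::jit::Graph> graph) {
  auto runner = glow::setGraphRunnerForKey(fusionSymbol, [graph] {
    return std::make_shared<glow::CachingGraphRunner>(
        graph, glow::getHostManager(), glow::getPyTorchLoaderSettings());
  });
  // Whether this call built the runner or another thread got there first,
  // the map now holds exactly one runner for this key; tests can check
  // glow::getGraphRunnerMapSize() to assert that.
  return runner;
}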
6 changes: 6 additions & 0 deletions torch_glow/tests/utils.py
@@ -3,6 +3,8 @@
import torch
import torch_glow

import sys

GLOW_NODE_NAME = "glow::FusionGroup"
SUBGRAPH_ATTR = "Subgraph"

@@ -84,9 +86,13 @@ def jitVsGlow_(f_torch, f_glow, check_trace, atol, rtol, *inputs, expected_fused
assert isinstance(torch_res, tuple) and isinstance(glow_res, tuple)
assert len(torch_res) == len(glow_res)
for i in range(len(torch_res)):
print("torch shape: {}".format(torch_res[i].shape), file=sys.stderr)
print("glow shape: {}".format(glow_res[i].shape), file=sys.stderr)
assert torch.allclose(
torch_res[i], glow_res[i], atol=atol, rtol=rtol)
else:
print("torch shape: {}".format(torch_res.shape), file=sys.stderr)
print("glow shape: {}".format(glow_res.shape), file=sys.stderr)
is_all_close = torch.allclose(
torch_res, glow_res, atol=atol, rtol=rtol)
if not is_all_close: