Skip to content
Permalink
Browse files

Perf test updates

Summary: Updates

Reviewed By: csummersea

Differential Revision: D20588455

fbshipit-source-id: ea9dd9444bc259ec3a352f3589f8c7045474ef2b
  • Loading branch information
jackm321 authored and facebook-github-bot committed Mar 26, 2020
1 parent b75a5d2 commit 54cbaf587ee78ab2365819ea24555a62ac22aebd
Showing with 15 additions and 3 deletions.
  1. +3 −3 torch_glow/src/CachingGraphRunner.cpp
  2. +4 −0 torch_glow/src/PyTorchCommon.h
  3. +8 −0 torch_glow/src/binding.cpp
@@ -147,7 +147,9 @@ CachingGraphRunner::loadImpl(torch::jit::Stack &stack,
cctx.backendOpts.backendSpecificOpts.insert(loadBackendSpecificOpts);
}

RETURN_IF_ERR(hostManager_->addNetwork(std::move(module), cctx));
RETURN_IF_ERR(hostManager_->addNetwork(std::move(module), cctx,
settings_.saturateHost));

TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME, "addNetwork");

auto ret = perGlowGraphInfoMap_.emplace(hash, info);
@@ -422,8 +424,6 @@ CachingGraphRunner::CachingGraphRunner(
hostManager_(hostManager), settings_(settings) {}

CachingGraphRunner::~CachingGraphRunner() {
aggregateAndDumpTraces(nullptr, /* flush */ true);

// Remove Glow functions saved in HostManager when being destroyed.
std::unique_lock<std::shared_timed_mutex> wlock(graphInfoMapMutex);
for (auto &kv : perGlowGraphInfoMap_) {
@@ -91,6 +91,10 @@ struct PyTorchLoaderSettings {

/// Name of a YAML file containing backend specific options.
std::string backendOptionsFile;

/// Whether not to set the saturateHost flag (use all available device) when
/// adding networks to HostManager.
bool saturateHost = false;
};

/// Given a PyTorch ScalarType \p ty, \returns a matching Glow ElemKind.
@@ -103,6 +103,14 @@ PYBIND11_MODULE(_torch_glow, m) {
m.def("disable_write_to_onnx",
[]() { getPyTorchLoaderSettings().writeToOnnx = false; });

/// Enable saturateHost mode in Glow runtime.
m.def("enable_saturate_host",
[]() { getPyTorchLoaderSettings().saturateHost = true; });

/// Disable saturateHost mode in Glow runtime.
m.def("disable_saturate_host",
[]() { getPyTorchLoaderSettings().saturateHost = false; });

/// Add all of the symbols in \p blacklist to the fusion blacklist so that
/// nodes with these symbols will not be fused to Glow.
m.def("setFusionBlacklist", [](const std::vector<std::string> &blacklist) {

0 comments on commit 54cbaf5

Please sign in to comment.
You can’t perform that action at this time.