Extended PyTorchFileLoader class. (#3391)

Summary:
Added loading of fused JIT graphs into a Glow Function.
Pull Request resolved: #3391

Test Plan: run unit tests

Differential Revision: D16678267

fbshipit-source-id: 4030849a886ca01b5c3793cd3c5130ef4c83317f
putivsky authored and facebook-github-bot committed Aug 11, 2019
1 parent 841093a commit 38d00aa3129b80b61b926bb06c026950c9867065
@@ -185,6 +185,7 @@ if (GLOW_BUILD_PYTORCH_INTEGRATION)
endif()

add_subdirectory(torch_glow/src)
add_subdirectory(torch_glow/tests/unittests)
endif()

if (GLOW_WITH_BUNDLES AND NOT GLOW_WITH_CPU)
@@ -115,6 +115,7 @@ def _run_cmake(self):
with cd(CMAKE_BUILD_DIR):
cmake_args = [
CMAKE,
'-DC10_USE_GLOG=1',
'-DGLOW_BUILD_PYTORCH_INTEGRATION=ON',
'-DBUILD_SHARED_LIBS=OFF',
'-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
@@ -124,6 +125,7 @@ def _run_cmake(self):
# PyTorch cmake args
'-DPYTORCH_DIR={}'.format(
os.path.dirname(os.path.realpath(torch.__file__))),
'-DTORCH_GLOW={}'.format(FILE_DIR),
]

if args.cmake_prefix_path:
@@ -11,26 +11,49 @@ message(STATUS "Using pytorch dir ${PYTORCH_DIR}")

link_directories(${PYTORCH_DIR}/lib)

pybind11_add_module(_torch_glow
binding.cpp
PyTorchCommon.cpp
PyTorchFileLoader.cpp
FusingOptimizer.cpp
PyTorchModelLoader.cpp
CachingGraphRunner.cpp)

target_compile_options(_torch_glow PRIVATE -frtti -fexceptions)
target_link_libraries(_torch_glow
add_library(PyTorchModelLoader
PyTorchCommon.cpp
FusingOptimizer.cpp
PyTorchModelLoader.cpp
CachingGraphRunner.cpp)
target_compile_options(PyTorchModelLoader
PRIVATE
-frtti -fexceptions -DC10_USE_GLOG)
target_link_libraries(PyTorchModelLoader
PRIVATE
torch
c10
Support
ExecutionEngine
Graph
Importer
Backends)

add_library(PyTorchFileLoader
PyTorchFileLoader.cpp)
target_compile_options(PyTorchFileLoader
PRIVATE
-frtti -fexceptions -DC10_USE_GLOG)
target_link_libraries(PyTorchFileLoader
PRIVATE
PyTorchModelLoader)

pybind11_add_module(_torch_glow
binding.cpp)

target_compile_options(_torch_glow
PRIVATE
-frtti -fexceptions -DC10_USE_GLOG)
target_link_libraries(_torch_glow
PRIVATE
PyTorchModelLoader
pybind11)

target_include_directories(PyTorchModelLoader PUBLIC
${PYTORCH_DIR}/include)

target_include_directories(PyTorchFileLoader PUBLIC
${PYTORCH_DIR}/include)

target_include_directories(_torch_glow PUBLIC
${PYTORCH_DIR}/include)
@@ -20,8 +20,8 @@
#include <torch/csrc/jit/ir.h>

namespace glow {

/// Performs a specific fusion for the Linear operator.
void FuseLinear(std::shared_ptr<torch::jit::Graph> &graph);
} // namespace glow

#endif // GLOW_TORCH_GLOW_SRC_FUSINGOPTIMIZER_H
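For illustration only: a pass like the FuseLinear declared above is commonly written as a JIT pattern rewrite using PyTorch's torch::jit::SubgraphRewriter. The sketch below assumes that approach; the function name fuseLinearSketch and the exact IR patterns are illustrative, not the commit's actual FusingOptimizer.cpp body.

#include <string>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// Sketch: rewrite the t() + addmm sequence that nn.Linear traces to into a
// single aten::linear node, so the fuser later sees one supported op. A
// production pass would also verify that beta and alpha are both 1 before
// rewriting.
void fuseLinearSketch(std::shared_ptr<torch::jit::Graph> &graph) {
  const std::string addmmPattern = R"IR(
graph(%input, %weight, %bias, %beta, %alpha):
  %wt = aten::t(%weight)
  %res = aten::addmm(%bias, %input, %wt, %beta, %alpha)
  return (%res))IR";
  const std::string linearPattern = R"IR(
graph(%input, %weight, %bias, %beta, %alpha):
  %res = aten::linear(%input, %weight, %bias)
  return (%res))IR";
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(addmmPattern, linearPattern);
  rewriter.runOnGraph(graph);
}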
@@ -15,6 +15,9 @@
*/

#include "PyTorchCommon.h"
#include "FusingOptimizer.h"
#include "PyTorchModelLoader.h"
#include <torch/csrc/jit/passes/graph_fuser.h>

namespace glow {

@@ -28,4 +31,17 @@ const c10::Symbol &getGlowSymbol() {
at::Symbol::fromQualString("glow::FusionGroup");
return glowSymbol;
}

void glowCustomFuse(std::shared_ptr<torch::jit::Graph> &g,
at::Symbol fuseSymbol) {
// Fuse all linear operators.
// Currently PyTorch does not have good support for aten::addmm when fusing,
// so we use a pattern rewrite to translate all aten::addmm nodes to
// aten::linear before fusing the whole graph.
FuseLinear(g);

torch::jit::CustomFuseGraph(g, PyTorchModelLoader::isNodeSupported,
fuseSymbol);
}

} // namespace glow
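A minimal caller-side sketch of the new glowCustomFuse entry point; the wrapper name fuseGraphForGlow is an assumption, and the graph is assumed to come from tracing or loading a module elsewhere.

#include "PyTorchCommon.h"
#include <torch/csrc/jit/ir.h>

// Sketch: run Glow's custom fusion over a JIT graph. Afterwards the
// Glow-supported region of the graph is collapsed into a single node whose
// kind is the symbol returned by getGlowSymbol(), i.e. "glow::FusionGroup".
void fuseGraphForGlow(std::shared_ptr<torch::jit::Graph> &graph) {
  glow::glowCustomFuse(graph, glow::getGlowSymbol());
}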
@@ -40,6 +40,11 @@ PyTorchLoaderSettings &getPyTorchLoaderSettings();
/// \returns the PyTorch symbol to be used for the PyTorch node which represents
/// the subgraph that Glow will compile and run.
const c10::Symbol &getGlowSymbol();

/// Executes the custom fuse pass for the given \p graph and \p fuseSymbol.
void glowCustomFuse(std::shared_ptr<torch::jit::Graph> &graph,
at::Symbol fuseSymbol);

} // namespace glow

#endif // GLOW_TORCH_GLOW_SRC_COMMON_H
@@ -15,30 +15,189 @@
*/

#include "PyTorchFileLoader.h"
#include "PyTorchModelLoader.h"
#include "glow/Support/Error.h"
#include "glow/Support/Support.h"
#include <ATen/core/grad_mode.h>
#include <torch/csrc/jit/operator_options.h>
#include <torch/csrc/jit/pass_manager.h>

namespace glow {

PyTorchFileLoader::PyTorchFileLoader(
const std::string &fileName,
std::shared_ptr<torch::jit::script::Module> &module, llvm::Error *errPtr) {
auto setup = [&]() -> llvm::Error {
try {
module = std::make_shared<torch::jit::script::Module>(
torch::jit::load(fileName));
} catch (const std::exception &x) {
RETURN_ERR(strFormat("Cannot load model from file: %s, , reason: %s",
fileName.c_str(), x.what()));
}
return llvm::Error::success();
};

if (errPtr) {
*errPtr = setup();
} else {
EXIT_ON_ERR(setup());
}
}

namespace {
/// Keeps pointers to the Glow Function and its Placeholders in a thread-local
/// variable used by the custom fusion pass.
struct LocalFusionFunction {
glow::Function *function{nullptr};
std::vector<glow::Placeholder *> *inputPlaceholders{nullptr};
std::vector<glow::Placeholder *> *outputPlaceholders{nullptr};
const PyTorchLoaderSettings *settings{nullptr};

void set(glow::Function *f, std::vector<glow::Placeholder *> *in,
std::vector<glow::Placeholder *> *out,
const PyTorchLoaderSettings *s) {
function = f;
inputPlaceholders = in;
outputPlaceholders = out;
settings = s;
}

void reset() { set(nullptr, nullptr, nullptr, nullptr); }
};

/// Meyers singleton holding the static fusion symbol.
static at::Symbol getFusionSymbol() {
/// The fusion pass's unique symbol.
static const auto fusionSymbol =
at::Symbol::fromQualString("glow::LoaderFusionPass");
return fusionSymbol;
}

/// Thread-local variable used by the PyTorch custom fusion pass.
static thread_local LocalFusionFunction localFusionInfo;

/// Loads JIT Graph into Glow Function.
llvm::Error
loadJitGraphToGlowFunction(torch::jit::Stack &stack, torch::jit::Graph &graph,
glow::Function &f,
std::vector<glow::Placeholder *> &inputPlaceholders,
std::vector<glow::Placeholder *> &outputPlaceholders,
const PyTorchLoaderSettings &settings) {
const auto &graphInputs = graph.inputs();
const auto numInputs = graphInputs.size();
auto inputs = torch::jit::last(stack, numInputs);

// Load JIT Graph into Glow Function.
RETURN_IF_ERR(PyTorchModelLoader::loadJITGraph(
f, graph, inputs, inputPlaceholders, outputPlaceholders, settings));

// Remove the input parameters from the stack.
torch::jit::drop(stack, numInputs);
// Use the output placeholders to determine the output shapes.
for (auto *ph : outputPlaceholders) {
std::vector<int64_t> sizes;
for (auto size : ph->dims()) {
sizes.push_back(static_cast<int64_t>(size));
}
// Create an empty tensor with the correct shape.
auto ptT = at::empty(sizes);
auto var = torch::autograd::make_variable(ptT);
// Push the output parameter onto the stack.
stack.push_back(at::IValue(var));
}

return llvm::Error::success();
}

/// Runs the Module forward pass and triggers the custom fusion pass if the
/// thread-local Glow function is set.
llvm::Error
evaluateModuleGraph(std::shared_ptr<torch::jit::script::Module> &module,
std::vector<torch::jit::IValue> &inputs) {
try {
module->forward(inputs);
} catch (const std::exception &x) {
RETURN_ERR(x.what());
}
return llvm::Error::success();
}

/// Helper struct whose constructor registers the custom fusion pass and
/// operator.
struct RegisterCustomFusionPass {
RegisterCustomFusionPass() {
auto options = c10::OperatorOptions();
options.setAliasAnalysis(at::AliasAnalysisKind::PURE);
torch::jit::RegisterOperators op({torch::jit::Operator(
getFusionSymbol(),
[](const torch::jit::Node *node) {
return [node](torch::jit::Stack &stack) {
// Get JIT Graph.
auto graph = node->g(at::attr::Subgraph);
auto err = loadJitGraphToGlowFunction(
stack, *graph, *localFusionInfo.function,
*localFusionInfo.inputPlaceholders,
*localFusionInfo.outputPlaceholders, *localFusionInfo.settings);
if (static_cast<bool>(err)) {
throw std::invalid_argument(llvm::toString(std::move(err)));
}
return 0;
};
},
options)});

torch::jit::RegisterPass pass([&](std::shared_ptr<torch::jit::Graph> &g) {
// Trigger the custom fusion pass only if the thread-local Glow Function is set.
if (localFusionInfo.function) {
glowCustomFuse(g, getFusionSymbol());
}
});
}
};
} // namespace

/*static*/
llvm::Error PyTorchFileLoader::loadPyTorchModel(
const std::string &fileName,
std::shared_ptr<torch::jit::script::Module> &module) {
try {
module = std::make_shared<torch::jit::script::Module>(
torch::jit::load(fileName));
} catch (const std::exception &x) {
RETURN_ERR(strFormat("Cannot load model from file: %s, reason: %s",
fileName.c_str(), x.what()));
}
return llvm::Error::success();
}

/*static*/
llvm::Error PyTorchFileLoader::loadPyTorchGraph(
const std::string &fileName, std::vector<torch::jit::IValue> &inputs,
glow::Function &F, std::vector<glow::Placeholder *> &inputPlaceholders,
std::vector<glow::Placeholder *> &outputPlaceholders,
const PyTorchLoaderSettings &settings, bool sanityCheck) {
// Register custom pass once.
static RegisterCustomFusionPass fuser;

// Convert PyTorch model into JIT Module.
std::shared_ptr<torch::jit::script::Module> module;
RETURN_IF_ERR(PyTorchFileLoader::loadPyTorchModel(fileName, module));

// Disable gradient node generation.
at::NoGradGuard guard;

// Set the thread-local pointers; non-null values activate the custom fusion pass.
localFusionInfo.set(&F, &inputPlaceholders, &outputPlaceholders, &settings);
auto err = evaluateModuleGraph(module, inputs);
localFusionInfo.reset();

RETURN_IF_ERR(err);

return sanityCheck ? performSanityCheck() : llvm::Error::success();
}

// Sanity check: after the "fusionSymbol" node, no other nodes should exist.
/*static*/
llvm::Error PyTorchFileLoader::performSanityCheck() {
std::shared_ptr<torch::jit::Graph> subgraph =
torch::jit::lastExecutedOptimizedGraph();
size_t fusedNodes = 0, missedNodes = 0;
const auto symbol = getFusionSymbol();
for (const auto &node : subgraph->nodes()) {
if (node->kind() == symbol) {
++fusedNodes;
} else if (fusedNodes) {
// A fusionSymbol node has already been found; report the unexpected trailing node.
LOG(ERROR) << "Missing node: " << *node;
++missedNodes;
}
}

RETURN_ERR_IF_NOT(
fusedNodes == 1 && missedNodes == 0,
glow::strFormat("Fused optimized nodes: %lu, missing nodes: %lu",
fusedNodes, missedNodes));
return llvm::Error::success();
}

} // namespace glow
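As a usage illustration (not part of the commit), the new static loadPyTorchModel entry point can be exercised on its own; the "model.pt" path and the input shape are assumptions.

#include "PyTorchFileLoader.h"
#include "llvm/Support/Error.h"
#include <glog/logging.h>
#include <torch/torch.h>

// Sketch: load a TorchScript file into a script::Module and run it eagerly,
// without any Glow involvement.
void runSavedModule() {
  std::shared_ptr<torch::jit::script::Module> module;
  if (auto err =
          glow::PyTorchFileLoader::loadPyTorchModel("model.pt", module)) {
    LOG(ERROR) << "Load failed: " << llvm::toString(std::move(err));
    return;
  }
  std::vector<torch::jit::IValue> inputs = {torch::randn({1, 4})};
  torch::jit::IValue output = module->forward(inputs);
  LOG(INFO) << "forward() returned a tensor: " << output.isTensor();
}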
@@ -17,19 +17,39 @@
#ifndef GLOW_TORCH_GLOW_SRC_PYTORCHFILELOADER_H
#define GLOW_TORCH_GLOW_SRC_PYTORCHFILELOADER_H

#include "PyTorchCommon.h"
#include "glow/Graph/Graph.h"
#include "llvm/Support/Error.h"
#include <torch/csrc/jit/import.h>

namespace glow {

/// Loads a PyTorch model from a file into JIT IR subgraphs.
class PyTorchFileLoader {
/// Performs a sanity check to make sure the custom fuse pass succeeded as
/// expected, \returns an error otherwise.
static llvm::Error performSanityCheck();

public:
/// Takes a model file \p fileName, loads model into torch Module \p module,
/// reports error assigning \p errPtr.
PyTorchFileLoader(const std::string &fileName,
std::shared_ptr<torch::jit::script::Module> &module,
llvm::Error *errPtr);
/// \returns error if any.
static llvm::Error
loadPyTorchModel(const std::string &fileName,
std::shared_ptr<torch::jit::script::Module> &module);

/// Takes a model file \p fileName, loads the optimized graph into the Glow
/// Function \p F, and fills the input \p inputPlaceholders and output
/// \p outputPlaceholders placeholders, \returns error if any. The \p settings
/// parameter controls the fusion details. Optionally the method performs a
/// sanity check if \p sanityCheck is set.
/// The method is thread safe; internally it uses thread-local structures for
/// executing the custom fusion pass, which is registered globally. No other
/// passes or other threads calling this method are affected.
static llvm::Error loadPyTorchGraph(
const std::string &fileName, std::vector<torch::jit::IValue> &inputs,
glow::Function &F, std::vector<glow::Placeholder *> &inputPlaceholders,
std::vector<glow::Placeholder *> &outputPlaceholders,
const PyTorchLoaderSettings &settings, bool sanityCheck = true);
};

} // namespace glow
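Finally, a hedged sketch of the main loadPyTorchGraph API declared above; the ExecutionEngine setup, the "model.pt" path, the function name, and the input shape are assumptions, with error handling reduced to EXIT_ON_ERR.

#include "PyTorchCommon.h"
#include "PyTorchFileLoader.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Support/Error.h"
#include <torch/torch.h>

// Sketch: let the loader trace a saved PyTorch model and load its fused
// subgraph directly into a Glow Function.
void loadIntoGlow() {
  glow::ExecutionEngine EE;
  glow::Module &mod = EE.getModule();
  glow::Function *F = mod.createFunction("pytorch_model");

  std::vector<torch::jit::IValue> inputs = {torch::randn({1, 4})};
  std::vector<glow::Placeholder *> inputPlaceholders;
  std::vector<glow::Placeholder *> outputPlaceholders;

  // Default loader settings; sanityCheck defaults to true.
  EXIT_ON_ERR(glow::PyTorchFileLoader::loadPyTorchGraph(
      "model.pt", inputs, *F, inputPlaceholders, outputPlaceholders,
      glow::getPyTorchLoaderSettings()));

  // The placeholders now describe F's inputs and outputs and can be bound to
  // tensors before compiling and running F with the ExecutionEngine.
}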
