Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Support more JIT IR nodes to PyTorchModelLoader #3181

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/pytorch.md
@@ -0,0 +1,26 @@
# Loading PyTorch models in Glow
**Warning:** PyTorch integration is still under development and does not yet have as much support as Caffe2 model loading.

## About
Import PyTorch models to Glow via the [PyTorch JIT IR](https://pytorch.org/docs/master/jit.html).

See `glow/torch_glow/examples` for illustrative examples.


## Setup
* Follow directions in Building.md to make sure Glow can be built
* install [PyTorch](https://pytorch.org/) nightly
* install torchvision: `pip install torchvision`
* cd to `glow/torch_glow`

## Usage
### Run tests
* `python setup.py test`
* use the `--cmake_prefix_path` flag to specify an llvm install location just like when building glow
* to disable capturing test outputs, add `addopts = -s` to `[tool:pytest]` in setup.cfg
### Temporarily install while developing on Glow
* `python setup.py develop`
* verify with installation worked with `import torch_glow` in Python
### Install
* `python setup.py install`
* verify with installation worked with `import torch_glow` in Python
File renamed without changes.
1 change: 1 addition & 0 deletions torch_glow/setup.cfg
Expand Up @@ -3,3 +3,4 @@ test=pytest

[tool:pytest]
testpaths = tests
addopts = --verbose
18 changes: 14 additions & 4 deletions torch_glow/setup.py
Expand Up @@ -48,11 +48,20 @@
# # Flags
# ################################################################################

# store first argument
assert len(sys.argv) > 0
first_arg = sys.argv[0]

# parse known arguments
parser = argparse.ArgumentParser()
parser.add_argument("--debug", type=bool, default=False, help="Compile with debug on")
parser.add_argument("--run_cmake", action='store_true', default=False, help="Run cmake")
parser.add_argument("--release", action='store_true', default=False, help="Compile with debug on")
parser.add_argument("--cmake_prefix_path", type=str, help="Populates -DCMAKE_PREFIX_PATH")
args = parser.parse_known_args()[0]

# restore first and remaining arguments to argv
arg_parse_res = parser.parse_known_args()
args = arg_parse_res[0]
sys.argv = [first_arg] + arg_parse_res[1]


# ################################################################################
Expand Down Expand Up @@ -97,7 +106,7 @@ def _run_cmake(self):
'-DGLOW_BUILD_PYTORCH_INTEGRATION=ON',
'-DBUILD_SHARED_LIBS=OFF',
'-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
'-DCMAKE_BUILD_TYPE={}'.format('Debug' if args.debug else 'Release'),
'-DCMAKE_BUILD_TYPE={}'.format('Release' if args.release else 'Debug'),
'-DPYTHON_EXECUTABLE={}'.format(sys.executable),
# PyTorch cmake args
'-DPYTORCH_DIR={}'.format(
Expand Down Expand Up @@ -127,7 +136,8 @@ def run(self):
is_initial_build = not os.path.exists(CMAKE_BUILD_DIR)
if is_initial_build:
os.makedirs(CMAKE_BUILD_DIR)
self._run_cmake()
if is_initial_build or args.run_cmake:
self._run_cmake()
self._run_build()


Expand Down
5 changes: 0 additions & 5 deletions torch_glow/src/CMakeLists.txt
@@ -1,4 +1,3 @@
# PYTORCH_DIR
if(DEFINED ENV{PYTORCH_DIR})
SET(PYTORCH_DIR $ENV{PYTORCH_DIR})
message(STATUS "Using PYTORCH_DIR from env")
Expand All @@ -20,18 +19,14 @@ pybind11_add_module(_torch_glow
target_compile_options(_torch_glow PRIVATE -frtti -fexceptions)
target_link_libraries(_torch_glow
PRIVATE
# pytorch
torch
caffe2
c10
# glow
Support
ExecutionEngine
Graph
Importer
Interpreter
Backends
# pybind
pybind11)

target_include_directories(_torch_glow PUBLIC
Expand Down
15 changes: 7 additions & 8 deletions torch_glow/src/CachingGraphRunner.cpp
Expand Up @@ -23,16 +23,15 @@
GraphInfo::GraphInfo(const torch::jit::Node *node)
: subgraph(node->g(at::attr::Subgraph)) {}

void CachingGraphRunner::addGraph(const torch::jit::Node *node) {
// TODO: just use node* as key
GraphInfo graphInfo(node);
jitNodeToInfoMap_.insert({node, std::move(graphInfo)});
}

void CachingGraphRunner::runGraph(const torch::jit::Node *node,
torch::jit::Stack &stack) {
// Make sure this is a valid id
assert(jitNodeToInfoMap_.count(node) > 0);
// If this is the first time this subgraph has been run then create a new
// GraphInfo object to store information about it.
if (!jitNodeToInfoMap_.count(node)) {
GraphInfo graphInfo(node);
jitNodeToInfoMap_.insert({node, std::move(graphInfo)});
}

auto &graphInfo = jitNodeToInfoMap_.at(node);

const at::ArrayRef<torch::jit::Value *> &graphInputs =
Expand Down
8 changes: 6 additions & 2 deletions torch_glow/src/CachingGraphRunner.h
Expand Up @@ -28,6 +28,8 @@ struct GraphInfo {
GraphInfo(const torch::jit::Node *node);
};

/// Responsible for maintaining a mapping from PyTorch subgraphs and their
/// unique input types to compiled Glow Functions.
class CachingGraphRunner {
/// Map of from PyTorch JIT Node containing contracted JIT subgraph to
/// to GraphInfo containing information relevent to Glow about that subgraph.
Expand All @@ -37,8 +39,10 @@ class CachingGraphRunner {
public:
CachingGraphRunner() = default;

void addGraph(const torch::jit::Node *node);

/// Given a PyTorch glow::CompilationGroup Node \p node that contains a
/// PyTorch subgraph and corresponding PyTorch Stack \p stack of inputs, run
/// that subgraph on those inputs. If this is the first time this node has
/// been seen then this first loads it as a Glow Function and compiles.
void runGraph(const torch::jit::Node *node, torch::jit::Stack &stack);
};

Expand Down