Skip to content

Commit

Permalink
Support more JIT IR nodes to PyTorchModelLoader (#3181)
Browse files Browse the repository at this point in the history
Summary:
* Add support for loading `Constant`, `add`, `sub`, `conv2d`, and `relu` PyTorch JIT IR Nodes
* Setup operator tests using PyTest and setuptools
* Added proper cl flags handling to setup.py

Documentation:
* doyxgen comments
* Add a short getting started guide in pytorch.md
Pull Request resolved: #3181

Test Plan:
`python setup.py test`

<img width="1438" alt="Screen Shot 2019-06-26 at 2 04 15 PM" src="https://user-images.githubusercontent.com/1740091/60217187-ff3bdf00-9820-11e9-8485-60f628823cde.png">

Differential Revision: D16042625

Pulled By: jackm321

fbshipit-source-id: e6cc14a3f2d60da8dafb5dbc44df4ee6e4e2ce59
  • Loading branch information
jackm321 authored and facebook-github-bot committed Jun 28, 2019
1 parent 296c6b0 commit 408cd8d
Show file tree
Hide file tree
Showing 19 changed files with 633 additions and 83 deletions.
26 changes: 26 additions & 0 deletions docs/pytorch.md
@@ -0,0 +1,26 @@
# Loading PyTorch models in Glow
**Warning:** PyTorch integration is still under development and does not yet have as much support as Caffe2 model loading.

## About
Import PyTorch models to Glow via the [PyTorch JIT IR](https://pytorch.org/docs/master/jit.html).

See `glow/torch_glow/examples` for illustrative examples.


## Setup
* Follow directions in Building.md to make sure Glow can be built
* install [PyTorch](https://pytorch.org/) nightly
* install torchvision: `pip install torchvision`
* cd to `glow/torch_glow`

## Usage
### Run tests
* `python setup.py test`
* use the `--cmake_prefix_path` flag to specify an llvm install location just like when building glow
* to disable capturing test outputs, add `addopts = -s` to `[tool:pytest]` in setup.cfg
### Temporarily install while developing on Glow
* `python setup.py develop`
* verify with installation worked with `import torch_glow` in Python
### Install
* `python setup.py install`
* verify with installation worked with `import torch_glow` in Python
File renamed without changes.
1 change: 1 addition & 0 deletions torch_glow/setup.cfg
Expand Up @@ -3,3 +3,4 @@ test=pytest

[tool:pytest]
testpaths = tests
addopts = --verbose
18 changes: 14 additions & 4 deletions torch_glow/setup.py
Expand Up @@ -48,11 +48,20 @@
# # Flags
# ################################################################################

# store first argument
assert len(sys.argv) > 0
first_arg = sys.argv[0]

# parse known arguments
parser = argparse.ArgumentParser()
parser.add_argument("--debug", type=bool, default=False, help="Compile with debug on")
parser.add_argument("--run_cmake", action='store_true', default=False, help="Run cmake")
parser.add_argument("--release", action='store_true', default=False, help="Compile with debug on")
parser.add_argument("--cmake_prefix_path", type=str, help="Populates -DCMAKE_PREFIX_PATH")
args = parser.parse_known_args()[0]

# restore first and remaining arguments to argv
arg_parse_res = parser.parse_known_args()
args = arg_parse_res[0]
sys.argv = [first_arg] + arg_parse_res[1]


# ################################################################################
Expand Down Expand Up @@ -97,7 +106,7 @@ def _run_cmake(self):
'-DGLOW_BUILD_PYTORCH_INTEGRATION=ON',
'-DBUILD_SHARED_LIBS=OFF',
'-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
'-DCMAKE_BUILD_TYPE={}'.format('Debug' if args.debug else 'Release'),
'-DCMAKE_BUILD_TYPE={}'.format('Release' if args.release else 'Debug'),
'-DPYTHON_EXECUTABLE={}'.format(sys.executable),
# PyTorch cmake args
'-DPYTORCH_DIR={}'.format(
Expand Down Expand Up @@ -127,7 +136,8 @@ def run(self):
is_initial_build = not os.path.exists(CMAKE_BUILD_DIR)
if is_initial_build:
os.makedirs(CMAKE_BUILD_DIR)
self._run_cmake()
if is_initial_build or args.run_cmake:
self._run_cmake()
self._run_build()


Expand Down
5 changes: 0 additions & 5 deletions torch_glow/src/CMakeLists.txt
@@ -1,4 +1,3 @@
# PYTORCH_DIR
if(DEFINED ENV{PYTORCH_DIR})
SET(PYTORCH_DIR $ENV{PYTORCH_DIR})
message(STATUS "Using PYTORCH_DIR from env")
Expand All @@ -20,18 +19,14 @@ pybind11_add_module(_torch_glow
target_compile_options(_torch_glow PRIVATE -frtti -fexceptions)
target_link_libraries(_torch_glow
PRIVATE
# pytorch
torch
caffe2
c10
# glow
Support
ExecutionEngine
Graph
Importer
Interpreter
Backends
# pybind
pybind11)

target_include_directories(_torch_glow PUBLIC
Expand Down
15 changes: 7 additions & 8 deletions torch_glow/src/CachingGraphRunner.cpp
Expand Up @@ -23,16 +23,15 @@
GraphInfo::GraphInfo(const torch::jit::Node *node)
: subgraph(node->g(at::attr::Subgraph)) {}

void CachingGraphRunner::addGraph(const torch::jit::Node *node) {
// TODO: just use node* as key
GraphInfo graphInfo(node);
jitNodeToInfoMap_.insert({node, std::move(graphInfo)});
}

void CachingGraphRunner::runGraph(const torch::jit::Node *node,
torch::jit::Stack &stack) {
// Make sure this is a valid id
assert(jitNodeToInfoMap_.count(node) > 0);
// If this is the first time this subgraph has been run then create a new
// GraphInfo object to store information about it.
if (!jitNodeToInfoMap_.count(node)) {
GraphInfo graphInfo(node);
jitNodeToInfoMap_.insert({node, std::move(graphInfo)});
}

auto &graphInfo = jitNodeToInfoMap_.at(node);

const at::ArrayRef<torch::jit::Value *> &graphInputs =
Expand Down
8 changes: 6 additions & 2 deletions torch_glow/src/CachingGraphRunner.h
Expand Up @@ -28,6 +28,8 @@ struct GraphInfo {
GraphInfo(const torch::jit::Node *node);
};

/// Responsible for maintaining a mapping from PyTorch subgraphs and their
/// unique input types to compiled Glow Functions.
class CachingGraphRunner {
/// Map of from PyTorch JIT Node containing contracted JIT subgraph to
/// to GraphInfo containing information relevent to Glow about that subgraph.
Expand All @@ -37,8 +39,10 @@ class CachingGraphRunner {
public:
CachingGraphRunner() = default;

void addGraph(const torch::jit::Node *node);

/// Given a PyTorch glow::CompilationGroup Node \p node that contains a
/// PyTorch subgraph and corresponding PyTorch Stack \p stack of inputs, run
/// that subgraph on those inputs. If this is the first time this node has
/// been seen then this first loads it as a Glow Function and compiles.
void runGraph(const torch::jit::Node *node, torch::jit::Stack &stack);
};

Expand Down

0 comments on commit 408cd8d

Please sign in to comment.