Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
720 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,10 @@ | |
.DS_Store | ||
*.so | ||
*.dylib | ||
*.egg-info | ||
*.egg | ||
*.eggs | ||
*.pyc | ||
|
||
GPATH | ||
GRTAGS | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
import torch_glow | ||
|
||
x = torch.randn(4) | ||
y = torch.randn(4) | ||
|
||
@torch.jit.script | ||
def foo(a, b): | ||
c = a.mul(b) | ||
a = c.mul(c) | ||
a = c.mul(a) | ||
d = c.div(a) | ||
return d | ||
|
||
print("original jit ir") | ||
print(foo.graph_for(x, y)) | ||
|
||
jit_res = foo(x, y) | ||
|
||
torch_glow.enableFusionPass() | ||
|
||
@torch.jit.script | ||
def foo_glow(a, b): | ||
return foo(a, b) | ||
|
||
print("glow jit ir") | ||
print(foo_glow.graph_for(x, y)) | ||
|
||
jit_glow_res = foo_glow(x, y) | ||
|
||
print("jit_res") | ||
print(jit_res) | ||
print("jit_glow_res") | ||
print(jit_glow_res) | ||
|
||
assert torch.allclose(jit_res, jit_glow_res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[aliases] | ||
test=pytest | ||
|
||
[tool:pytest] | ||
testpaths = tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
from distutils.spawn import find_executable | ||
from distutils import sysconfig, log | ||
import setuptools | ||
import setuptools.command.build_py | ||
import setuptools.command.develop | ||
import setuptools.command.build_ext | ||
|
||
from collections import namedtuple | ||
from contextlib import contextmanager | ||
import glob | ||
import os | ||
import shlex | ||
import subprocess | ||
import sys | ||
import argparse | ||
from textwrap import dedent | ||
import multiprocessing | ||
|
||
try: | ||
import torch | ||
except ImportError as e: | ||
print('Unable to import torch. Error:') | ||
print('\t', e) | ||
print('You need to install pytorch first.') | ||
sys.exit(1) | ||
|
||
|
||
FILE_DIR = os.path.realpath(os.path.dirname(__file__)) | ||
TOP_DIR = os.path.realpath(os.path.dirname(FILE_DIR)) | ||
CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'build') | ||
|
||
CMAKE = find_executable('cmake') or find_executable('cmake3') | ||
if not CMAKE: | ||
print('Could not find "cmake". Make sure it is in your PATH.') | ||
sys.exit(1) | ||
|
||
install_requires = [] | ||
setup_requires = [] | ||
tests_require = [] | ||
extras_require = {} | ||
|
||
# ################################################################################ | ||
# # Flags | ||
# ################################################################################ | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--debug", type=bool, default=False, help="Compile with debug on") | ||
parser.add_argument("--cmake_prefix_path", type=str, help="Populates -DCMAKE_PREFIX_PATH") | ||
args = parser.parse_known_args()[0] | ||
|
||
|
||
# ################################################################################ | ||
# # Utilities | ||
# ################################################################################ | ||
|
||
|
||
@contextmanager | ||
def cd(path): | ||
if not os.path.isabs(path): | ||
raise RuntimeError('Can only cd to absolute path, got: {}'.format(path)) | ||
orig_path = os.getcwd() | ||
os.chdir(path) | ||
try: | ||
yield | ||
finally: | ||
os.chdir(orig_path) | ||
|
||
|
||
# ################################################################################ | ||
# # Customized commands | ||
# ################################################################################ | ||
|
||
|
||
class cmake_build(setuptools.Command): | ||
""" | ||
Compiles everything when `python setup.py develop` is run using cmake. | ||
Custom args can be passed to cmake by specifying the `CMAKE_ARGS` | ||
environment variable. | ||
""" | ||
def initialize_options(self): | ||
pass | ||
|
||
def finalize_options(self): | ||
pass | ||
|
||
def _run_cmake(self): | ||
with cd(CMAKE_BUILD_DIR): | ||
cmake_args = [ | ||
CMAKE, | ||
'-DGLOW_BUILD_PYTORCH_INTEGRATION=ON', | ||
'-DBUILD_SHARED_LIBS=OFF', | ||
'-DCMAKE_EXPORT_COMPILE_COMMANDS=ON', | ||
'-DCMAKE_BUILD_TYPE={}'.format('Debug' if args.debug else 'Release'), | ||
'-DPYTHON_EXECUTABLE={}'.format(sys.executable), | ||
# PyTorch cmake args | ||
'-DPYTORCH_DIR={}'.format( | ||
os.path.dirname(os.path.realpath(torch.__file__))), | ||
] | ||
|
||
if args.cmake_prefix_path: | ||
cmake_args.append('-DCMAKE_PREFIX_PATH={}'.format(args.cmake_prefix_path)) | ||
|
||
if 'CMAKE_ARGS' in os.environ: | ||
extra_cmake_args = shlex.split(os.environ['CMAKE_ARGS']) | ||
log.info('Extra cmake args: {}'.format(extra_cmake_args)) | ||
cmake_args.extend(extra_cmake_args) | ||
cmake_args.append(TOP_DIR) | ||
subprocess.check_call(cmake_args) | ||
|
||
def _run_build(self): | ||
with cd(CMAKE_BUILD_DIR): | ||
build_args = [ | ||
CMAKE, | ||
'--build', os.curdir, | ||
'--', '-j', str(multiprocessing.cpu_count()), | ||
] | ||
subprocess.check_call(build_args) | ||
|
||
def run(self): | ||
is_initial_build = not os.path.exists(CMAKE_BUILD_DIR) | ||
if is_initial_build: | ||
os.makedirs(CMAKE_BUILD_DIR) | ||
self._run_cmake() | ||
self._run_build() | ||
|
||
|
||
class develop(setuptools.command.develop.develop): | ||
def run(self): | ||
self.run_command('build_ext') | ||
setuptools.command.develop.develop.run(self) | ||
|
||
|
||
class build_ext(setuptools.command.build_ext.build_ext): | ||
def run(self): | ||
self.run_command('cmake_build') | ||
setuptools.command.build_ext.build_ext.run(self) | ||
|
||
def build_extensions(self): | ||
for ext in self.extensions: | ||
fullname = self.get_ext_fullname(ext.name) | ||
filename = os.path.basename(self.get_ext_filename(fullname)) | ||
|
||
src = os.path.join(CMAKE_BUILD_DIR, "torch_glow", "src", filename) | ||
dst = os.path.join(os.path.realpath(self.build_lib), 'torch_glow', filename) | ||
print("dst", dst) | ||
if not os.path.exists(os.path.dirname(dst)): | ||
os.makedirs(os.path.dirname(dst)) | ||
self.copy_file(src, dst) | ||
|
||
|
||
cmdclass = { | ||
'cmake_build': cmake_build, | ||
'develop': develop, | ||
'build_ext': build_ext, | ||
} | ||
|
||
# ################################################################################ | ||
# # Extensions | ||
# ################################################################################ | ||
|
||
ext_modules = [ | ||
setuptools.Extension( | ||
name=str('torch_glow._torch_glow'), | ||
sources=[]) | ||
] | ||
|
||
# ################################################################################ | ||
# # Packages | ||
# ################################################################################ | ||
|
||
# # no need to do fancy stuff so far | ||
packages = setuptools.find_packages() | ||
|
||
# ################################################################################ | ||
# # Test | ||
# ################################################################################ | ||
|
||
setup_requires.append('pytest-runner') | ||
tests_require.append('pytest') | ||
|
||
# ################################################################################ | ||
# # Final | ||
# ################################################################################ | ||
|
||
setuptools.setup( | ||
name="torch_glow", | ||
description="PyTorch + Glow", | ||
ext_modules=ext_modules, | ||
cmdclass=cmdclass, | ||
packages=packages, | ||
include_package_data=True, | ||
install_requires=install_requires, | ||
setup_requires=setup_requires, | ||
tests_require=tests_require, | ||
extras_require=extras_require, | ||
author='jackm321', | ||
author_email='jackmontgomery@fb.com', | ||
url='https://github.com/pytorch/glow', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# PYTORCH_DIR | ||
if(DEFINED ENV{PYTORCH_DIR}) | ||
SET(PYTORCH_DIR $ENV{PYTORCH_DIR}) | ||
message(STATUS "Using PYTORCH_DIR from env") | ||
endif() | ||
|
||
if(NOT EXISTS "${PYTORCH_DIR}") | ||
message(FATAL_ERROR "No PyTorch installation found") | ||
endif() | ||
|
||
message(STATUS "Using pytorch dir ${PYTORCH_DIR}") | ||
|
||
link_directories(${PYTORCH_DIR}/lib) | ||
|
||
pybind11_add_module(_torch_glow | ||
binding.cpp | ||
PyTorchModelLoader.cpp | ||
CachingGraphRunner.cpp) | ||
|
||
target_compile_options(_torch_glow PRIVATE -frtti -fexceptions) | ||
target_link_libraries(_torch_glow | ||
PRIVATE | ||
# pytorch | ||
torch | ||
caffe2 | ||
c10 | ||
# glow | ||
Support | ||
ExecutionEngine | ||
Graph | ||
Importer | ||
Interpreter | ||
Backends | ||
# pybind | ||
pybind11) | ||
|
||
target_include_directories(_torch_glow PUBLIC | ||
${PYTORCH_DIR}/include | ||
) |
Oops, something went wrong.