Skip to content

Commit

Permalink
Setup pytorch integration point
Browse files Browse the repository at this point in the history
  • Loading branch information
jackm321 committed Jun 21, 2019
1 parent 3ac2bac commit ee02eb2
Show file tree
Hide file tree
Showing 15 changed files with 720 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -6,6 +6,10 @@
.DS_Store
*.so
*.dylib
*.egg-info
*.egg
*.eggs
*.pyc

GPATH
GRTAGS
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Expand Up @@ -16,3 +16,6 @@
[submodule "googlebenchmark"]
path = tests/googlebenchmark
url = https://github.com/google/benchmark
[submodule "thirdparty/pybind11"]
path = thirdparty/pybind11
url = https://github.com/pybind/pybind11.git
10 changes: 10 additions & 0 deletions CMakeLists.txt
Expand Up @@ -10,6 +10,7 @@ option(GLOW_WITH_LLVMIRCODEGEN "Build the LLVM-based code generation library" ON
option(GLOW_WITH_OPENCL "Build the OpenCL backend" OFF)
option(GLOW_WITH_HABANA "Build the Habana backend" OFF)
option(GLOW_BUILD_EXAMPLES "Build the examples" ON)
option(GLOW_BUILD_PYTORCH_INTEGRATION "Build integration for PyTorch" OFF)
option(GLOW_BUILD_TESTS "Build the tests" ON)
option(GLOW_WITH_BUNDLES "Build bundles" OFF)
option(LINK_PROTOBUF_AS_DLL "Link against protobuf build as dynamic libray." OFF)
Expand Down Expand Up @@ -176,6 +177,15 @@ if (GLOW_BUILD_EXAMPLES)
add_subdirectory(examples)
endif()

if (GLOW_BUILD_PYTORCH_INTEGRATION)
if(NOT EXISTS "${GLOW_THIRDPARTY_DIR}/pybind11")
message(FATAL_ERROR "No pybind11 git submodule. Run: git submodule update --init --recursive")
endif()

add_subdirectory(${GLOW_THIRDPARTY_DIR}/pybind11)
add_subdirectory(torch_glow/src)
endif()

if (GLOW_WITH_BUNDLES AND NOT GLOW_WITH_CPU)
message(FATAL_ERROR "Cannot create bundles without CPU backend. Configure with -DGLOW_WITH_BUNDLES and -DGLOW_WITH_CPU to build bundles.")
endif()
Expand Down
1 change: 1 addition & 0 deletions thirdparty/pybind11
Submodule pybind11 added at a1b71d
36 changes: 36 additions & 0 deletions torch_glow/examples/example.py
@@ -0,0 +1,36 @@
import torch
import torch_glow

x = torch.randn(4)
y = torch.randn(4)

@torch.jit.script
def foo(a, b):
c = a.mul(b)
a = c.mul(c)
a = c.mul(a)
d = c.div(a)
return d

print("original jit ir")
print(foo.graph_for(x, y))

jit_res = foo(x, y)

torch_glow.enableFusionPass()

@torch.jit.script
def foo_glow(a, b):
return foo(a, b)

print("glow jit ir")
print(foo_glow.graph_for(x, y))

jit_glow_res = foo_glow(x, y)

print("jit_res")
print(jit_res)
print("jit_glow_res")
print(jit_glow_res)

assert torch.allclose(jit_res, jit_glow_res)
5 changes: 5 additions & 0 deletions torch_glow/setup.cfg
@@ -0,0 +1,5 @@
[aliases]
test=pytest

[tool:pytest]
testpaths = tests
206 changes: 206 additions & 0 deletions torch_glow/setup.py
@@ -0,0 +1,206 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from distutils.spawn import find_executable
from distutils import sysconfig, log
import setuptools
import setuptools.command.build_py
import setuptools.command.develop
import setuptools.command.build_ext

from collections import namedtuple
from contextlib import contextmanager
import glob
import os
import shlex
import subprocess
import sys
import argparse
from textwrap import dedent
import multiprocessing

try:
import torch
except ImportError as e:
print('Unable to import torch. Error:')
print('\t', e)
print('You need to install pytorch first.')
sys.exit(1)


FILE_DIR = os.path.realpath(os.path.dirname(__file__))
TOP_DIR = os.path.realpath(os.path.dirname(FILE_DIR))
CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'build')

CMAKE = find_executable('cmake') or find_executable('cmake3')
if not CMAKE:
print('Could not find "cmake". Make sure it is in your PATH.')
sys.exit(1)

install_requires = []
setup_requires = []
tests_require = []
extras_require = {}

# ################################################################################
# # Flags
# ################################################################################


parser = argparse.ArgumentParser()
parser.add_argument("--debug", type=bool, default=False, help="Compile with debug on")
parser.add_argument("--cmake_prefix_path", type=str, help="Populates -DCMAKE_PREFIX_PATH")
args = parser.parse_known_args()[0]


# ################################################################################
# # Utilities
# ################################################################################


@contextmanager
def cd(path):
if not os.path.isabs(path):
raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
orig_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(orig_path)


# ################################################################################
# # Customized commands
# ################################################################################


class cmake_build(setuptools.Command):
"""
Compiles everything when `python setup.py develop` is run using cmake.
Custom args can be passed to cmake by specifying the `CMAKE_ARGS`
environment variable.
"""
def initialize_options(self):
pass

def finalize_options(self):
pass

def _run_cmake(self):
with cd(CMAKE_BUILD_DIR):
cmake_args = [
CMAKE,
'-DGLOW_BUILD_PYTORCH_INTEGRATION=ON',
'-DBUILD_SHARED_LIBS=OFF',
'-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
'-DCMAKE_BUILD_TYPE={}'.format('Debug' if args.debug else 'Release'),
'-DPYTHON_EXECUTABLE={}'.format(sys.executable),
# PyTorch cmake args
'-DPYTORCH_DIR={}'.format(
os.path.dirname(os.path.realpath(torch.__file__))),
]

if args.cmake_prefix_path:
cmake_args.append('-DCMAKE_PREFIX_PATH={}'.format(args.cmake_prefix_path))

if 'CMAKE_ARGS' in os.environ:
extra_cmake_args = shlex.split(os.environ['CMAKE_ARGS'])
log.info('Extra cmake args: {}'.format(extra_cmake_args))
cmake_args.extend(extra_cmake_args)
cmake_args.append(TOP_DIR)
subprocess.check_call(cmake_args)

def _run_build(self):
with cd(CMAKE_BUILD_DIR):
build_args = [
CMAKE,
'--build', os.curdir,
'--', '-j', str(multiprocessing.cpu_count()),
]
subprocess.check_call(build_args)

def run(self):
is_initial_build = not os.path.exists(CMAKE_BUILD_DIR)
if is_initial_build:
os.makedirs(CMAKE_BUILD_DIR)
self._run_cmake()
self._run_build()


class develop(setuptools.command.develop.develop):
def run(self):
self.run_command('build_ext')
setuptools.command.develop.develop.run(self)


class build_ext(setuptools.command.build_ext.build_ext):
def run(self):
self.run_command('cmake_build')
setuptools.command.build_ext.build_ext.run(self)

def build_extensions(self):
for ext in self.extensions:
fullname = self.get_ext_fullname(ext.name)
filename = os.path.basename(self.get_ext_filename(fullname))

src = os.path.join(CMAKE_BUILD_DIR, "torch_glow", "src", filename)
dst = os.path.join(os.path.realpath(self.build_lib), 'torch_glow', filename)
print("dst", dst)
if not os.path.exists(os.path.dirname(dst)):
os.makedirs(os.path.dirname(dst))
self.copy_file(src, dst)


cmdclass = {
'cmake_build': cmake_build,
'develop': develop,
'build_ext': build_ext,
}

# ################################################################################
# # Extensions
# ################################################################################

ext_modules = [
setuptools.Extension(
name=str('torch_glow._torch_glow'),
sources=[])
]

# ################################################################################
# # Packages
# ################################################################################

# # no need to do fancy stuff so far
packages = setuptools.find_packages()

# ################################################################################
# # Test
# ################################################################################

setup_requires.append('pytest-runner')
tests_require.append('pytest')

# ################################################################################
# # Final
# ################################################################################

setuptools.setup(
name="torch_glow",
description="PyTorch + Glow",
ext_modules=ext_modules,
cmdclass=cmdclass,
packages=packages,
include_package_data=True,
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
extras_require=extras_require,
author='jackm321',
author_email='jackmontgomery@fb.com',
url='https://github.com/pytorch/glow',
)
39 changes: 39 additions & 0 deletions torch_glow/src/CMakeLists.txt
@@ -0,0 +1,39 @@
# PYTORCH_DIR
if(DEFINED ENV{PYTORCH_DIR})
SET(PYTORCH_DIR $ENV{PYTORCH_DIR})
message(STATUS "Using PYTORCH_DIR from env")
endif()

if(NOT EXISTS "${PYTORCH_DIR}")
message(FATAL_ERROR "No PyTorch installation found")
endif()

message(STATUS "Using pytorch dir ${PYTORCH_DIR}")

link_directories(${PYTORCH_DIR}/lib)

pybind11_add_module(_torch_glow
binding.cpp
PyTorchModelLoader.cpp
CachingGraphRunner.cpp)

target_compile_options(_torch_glow PRIVATE -frtti -fexceptions)
target_link_libraries(_torch_glow
PRIVATE
# pytorch
torch
caffe2
c10
# glow
Support
ExecutionEngine
Graph
Importer
Interpreter
Backends
# pybind
pybind11)

target_include_directories(_torch_glow PUBLIC
${PYTORCH_DIR}/include
)

0 comments on commit ee02eb2

Please sign in to comment.