diff --git a/docs/pytorch.md b/docs/pytorch.md
new file mode 100644
index 0000000000..a8d298da38
--- /dev/null
+++ b/docs/pytorch.md
@@ -0,0 +1,26 @@
+# Loading PyTorch models in Glow
+**Warning:** PyTorch integration is still under development and does not yet have as much support as Caffe2 model loading.
+
+## About
+Import PyTorch models into Glow via the [PyTorch JIT IR](https://pytorch.org/docs/master/jit.html).
+
+See `glow/torch_glow/examples` for illustrative examples.
+
+
+## Setup
+* Follow the directions in Building.md to make sure Glow can be built
+* Install [PyTorch](https://pytorch.org/) nightly
+* Install torchvision: `pip install torchvision`
+* `cd` to `glow/torch_glow`
+
+## Usage
+### Run tests
+* `python setup.py test`
+  * use the `--cmake_prefix_path` flag to specify an LLVM install location, just as when building Glow
+  * to disable capturing test output, add `addopts = -s` to `[tool:pytest]` in setup.cfg
+### Temporarily install while developing on Glow
+* `python setup.py develop`
+  * verify the installation worked with `import torch_glow` in Python
+### Install
+* `python setup.py install`
+  * verify the installation worked with `import torch_glow` in Python
\ No newline at end of file
diff --git a/torch_glow/examples/example.py b/torch_glow/examples/basic_example.py
similarity index 100%
rename from torch_glow/examples/example.py
rename to torch_glow/examples/basic_example.py
diff --git a/torch_glow/setup.cfg b/torch_glow/setup.cfg
index 601a2f3fff..b3481a4ff6 100644
--- a/torch_glow/setup.cfg
+++ b/torch_glow/setup.cfg
@@ -3,3 +3,4 @@ test=pytest
 
 [tool:pytest]
 testpaths = tests
+addopts = --verbose
\ No newline at end of file
diff --git a/torch_glow/setup.py b/torch_glow/setup.py
index cf90964696..908783e688 100644
--- a/torch_glow/setup.py
+++ b/torch_glow/setup.py
@@ -48,11 +48,20 @@
 #
 # Flags
 #
 ################################################################################
+# store first argument
+assert len(sys.argv) > 0
+first_arg = sys.argv[0]
+
+# parse known arguments
 parser = argparse.ArgumentParser()
-parser.add_argument("--debug", type=bool, default=False, help="Compile with debug on")
+parser.add_argument("--run_cmake", action='store_true', default=False, help="Run cmake")
+parser.add_argument("--release", action='store_true', default=False, help="Compile in Release mode (default is Debug)")
 parser.add_argument("--cmake_prefix_path", type=str, help="Populates -DCMAKE_PREFIX_PATH")
-args = parser.parse_known_args()[0]
+
+# restore first and remaining arguments to argv
+arg_parse_res = parser.parse_known_args()
+args = arg_parse_res[0]
+sys.argv = [first_arg] + arg_parse_res[1]
 
 #
 ################################################################################
@@ -97,7 +106,7 @@ def _run_cmake(self):
             '-DGLOW_BUILD_PYTORCH_INTEGRATION=ON',
             '-DBUILD_SHARED_LIBS=OFF',
             '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
-            '-DCMAKE_BUILD_TYPE={}'.format('Debug' if args.debug else 'Release'),
+            '-DCMAKE_BUILD_TYPE={}'.format('Release' if args.release else 'Debug'),
             '-DPYTHON_EXECUTABLE={}'.format(sys.executable),
             # PyTorch cmake args
             '-DPYTORCH_DIR={}'.format(
@@ -127,7 +136,8 @@ def run(self):
         is_initial_build = not os.path.exists(CMAKE_BUILD_DIR)
         if is_initial_build:
             os.makedirs(CMAKE_BUILD_DIR)
-        self._run_cmake()
+        if is_initial_build or args.run_cmake:
+            self._run_cmake()
 
         self._run_build()
 
diff --git a/torch_glow/src/CMakeLists.txt b/torch_glow/src/CMakeLists.txt
index 54075563ca..11cf779f1c 100644
--- a/torch_glow/src/CMakeLists.txt
+++ b/torch_glow/src/CMakeLists.txt
@@ -1,4 +1,3 @@
-# 
PYTORCH_DIR if(DEFINED ENV{PYTORCH_DIR}) SET(PYTORCH_DIR $ENV{PYTORCH_DIR}) message(STATUS "Using PYTORCH_DIR from env") @@ -20,18 +19,14 @@ pybind11_add_module(_torch_glow target_compile_options(_torch_glow PRIVATE -frtti -fexceptions) target_link_libraries(_torch_glow PRIVATE - # pytorch torch - caffe2 c10 - # glow Support ExecutionEngine Graph Importer Interpreter Backends - # pybind pybind11) target_include_directories(_torch_glow PUBLIC diff --git a/torch_glow/src/CachingGraphRunner.cpp b/torch_glow/src/CachingGraphRunner.cpp index 86a7b29e75..02471d5774 100644 --- a/torch_glow/src/CachingGraphRunner.cpp +++ b/torch_glow/src/CachingGraphRunner.cpp @@ -23,16 +23,15 @@ GraphInfo::GraphInfo(const torch::jit::Node *node) : subgraph(node->g(at::attr::Subgraph)) {} -void CachingGraphRunner::addGraph(const torch::jit::Node *node) { - // TODO: just use node* as key - GraphInfo graphInfo(node); - jitNodeToInfoMap_.insert({node, std::move(graphInfo)}); -} - void CachingGraphRunner::runGraph(const torch::jit::Node *node, torch::jit::Stack &stack) { - // Make sure this is a valid id - assert(jitNodeToInfoMap_.count(node) > 0); + // If this is the first time this subgraph has been run then create a new + // GraphInfo object to store information about it. + if (!jitNodeToInfoMap_.count(node)) { + GraphInfo graphInfo(node); + jitNodeToInfoMap_.insert({node, std::move(graphInfo)}); + } + auto &graphInfo = jitNodeToInfoMap_.at(node); const at::ArrayRef &graphInputs = diff --git a/torch_glow/src/CachingGraphRunner.h b/torch_glow/src/CachingGraphRunner.h index 6f9eddf481..f02cd6f77d 100644 --- a/torch_glow/src/CachingGraphRunner.h +++ b/torch_glow/src/CachingGraphRunner.h @@ -28,6 +28,8 @@ struct GraphInfo { GraphInfo(const torch::jit::Node *node); }; +/// Responsible for maintaining a mapping from PyTorch subgraphs and their +/// unique input types to compiled Glow Functions. class CachingGraphRunner { /// Map of from PyTorch JIT Node containing contracted JIT subgraph to /// to GraphInfo containing information relevent to Glow about that subgraph. @@ -37,8 +39,10 @@ class CachingGraphRunner { public: CachingGraphRunner() = default; - void addGraph(const torch::jit::Node *node); - + /// Given a PyTorch glow::CompilationGroup Node \p node that contains a + /// PyTorch subgraph and corresponding PyTorch Stack \p stack of inputs, run + /// that subgraph on those inputs. If this is the first time this node has + /// been seen then this first loads it as a Glow Function and compiles. void runGraph(const torch::jit::Node *node, torch::jit::Stack &stack); }; diff --git a/torch_glow/src/PyTorchModelLoader.cpp b/torch_glow/src/PyTorchModelLoader.cpp index f22ecda9e9..b327f2a25e 100644 --- a/torch_glow/src/PyTorchModelLoader.cpp +++ b/torch_glow/src/PyTorchModelLoader.cpp @@ -15,8 +15,11 @@ */ #include "PyTorchModelLoader.h" +#include namespace { +/// \returns a corresponding Glow Type for a given PyTorch CompleteTensorType \p +/// ptType. glow::Type ptTypeToGlowType(const at::CompleteTensorType &ptType) { // TODO: get correct ElemKind std::vector dims; @@ -25,47 +28,357 @@ glow::Type ptTypeToGlowType(const at::CompleteTensorType &ptType) { } return glow::Type(glow::ElemKind::FloatTy, dims); } + +/// Downcast a double to a float. +float to32Bit(double val) { + assert(val <= std::numeric_limits::max()); + assert(val >= std::numeric_limits::lowest()); + return static_cast(val); +} + +/// Downcast an int64_t to a int32_t. 
+int32_t to32Bit(int64_t val) { + assert(val <= std::numeric_limits::max()); + assert(val >= std::numeric_limits::lowest()); + return static_cast(val); +} + +/// Given a Glow Node \p glowNode which must be a Glow Constant, returns a Glow +/// Handle to that Constant's payload. +template +const glow::Handle getGlowConstantPayload(const glow::Node *glowNode) { + assert(glowNode); + const glow::Constant *glowConstant = llvm::dyn_cast(glowNode); + assert(glowConstant); + if (std::is_same::value) { + assert(glowConstant->getElementType() == glow::ElemKind::Int32ITy); + } else if (std::is_same::value) { + assert(glowConstant->getElementType() == glow::ElemKind::FloatTy); + } else if (std::is_same::value) { + assert(glowConstant->getElementType() == glow::ElemKind::BoolTy); + } else { + assert(false && "Unsupported type"); + } + return glowConstant->getPayload().getHandle(); +} + +/// Given a Glow Handle \p handle and \p targetSize, will return an std::vector +/// of the elements in the Tensor the Handle manages if the size of that Tensor +/// is targetSize or if the Tensor has only one element in it, will replicate +/// that element targetSize times. +template +std::vector expandParamIfNeeded(const glow::Handle &handle, + size_t targetSize) { + assert(handle.dims().size() == 1 && "Expected a 1d Glow Tensor"); + assert(handle.size() == 1 || + handle.size() == targetSize && + "Expected input size to be either 1 or the target size"); + std::vector out; + out.resize(targetSize); + for (size_t i = 0; i < targetSize; ++i) { + out[i] = + static_cast(handle.size() == 1 ? handle.raw(0) : handle.raw(i)); + } + return out; +} + +/// Given a Glow Handle \p Handle returns the first element in the Tensor +/// of that Handle, asserting it is equivalent to all other elements in the +/// Tensor. +template T contractParamIfNeeded(const glow::Handle &handle) { + assert(handle.size() > 0); + const T firstElem = handle.raw(0); + for (size_t i = 1; i < handle.size(); i++) { + assert(handle.raw(i) == firstElem); + } + return firstElem; +} } // namespace -glow::Placeholder *PyTorchModelLoader::loadValue(const torch::jit::Value *val) { - // TODO: do we need to care about optional IValues here? 
- assert(val->isCompleteTensor()); - auto ptType = val->type()->cast(); +glow::NodeValue +PyTorchModelLoader::getGlowNodeValue(const torch::jit::Value *value) const { + const auto it = valueMap_.find(value); + assert(it != valueMap_.end()); + return it->second; +} + +bool PyTorchModelLoader::hasGlowNodeValue( + const torch::jit::Value *value) const { + const auto it = valueMap_.find(value); + return it != valueMap_.end(); +} + +void PyTorchModelLoader::addGlowNodeValue(const torch::jit::Value *value, + glow::NodeValue nodeValue) { + assert(valueMap_.count(value) == 0); + valueMap_[value] = nodeValue; +} + +template +glow::Handle PyTorchModelLoader::getGlowConstantHandle( + const torch::jit::Value *value) const { + glow::NodeValue nodeValue = getGlowNodeValue(value); + return getGlowConstantPayload(nodeValue.getNode()); +} + +glow::Placeholder * +PyTorchModelLoader::loadValue(const torch::jit::Value *value) { + assert(value->isCompleteTensor()); + auto ptType = value->type()->cast(); auto glowType = ptTypeToGlowType(*ptType.get()); glow::Placeholder *ph = mod_.createPlaceholder(&glowType, "input", /*isTrainable*/ false); - valueMap_[val] = ph->getOutput(); + addGlowNodeValue(value, ph->getOutput()); return ph; } -void PyTorchModelLoader::loadNode(const torch::jit::Node *ptNode) { - glow::Node *glowNode = nullptr; +void PyTorchModelLoader::loadMul(const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + assert(inputs.size() == 2); + assert(outputs.size() == 1); + + glow::NodeValue rhs = getGlowNodeValue(inputs[0]); + glow::NodeValue lhs = getGlowNodeValue(inputs[1]); + glow::MulNode *glowNode = f_->createMul("mul", rhs, lhs); + addGlowNodeValue(outputs[0], glowNode->getResult()); +} +void PyTorchModelLoader::loadDiv(const torch::jit::Node *ptNode) { auto inputs = ptNode->inputs(); auto outputs = ptNode->outputs(); + assert(inputs.size() == 2); + assert(outputs.size() == 1); - switch (ptNode->kind()) { - case at::aten::mul: { - assert(inputs.size() == 2); - glow::NodeValue rhs = valueMap_.at(inputs[0]); - glow::NodeValue lhs = valueMap_.at(inputs[1]); - glowNode = f_->createMul("mul", rhs, lhs); - valueMap_[outputs[0]] = glowNode->getNthResult(glow::MulNode::ResultIdx); - break; + glow::NodeValue rhs = getGlowNodeValue(inputs[0]); + glow::NodeValue lhs = getGlowNodeValue(inputs[1]); + glow::DivNode *glowNode = f_->createDiv("div", rhs, lhs); + addGlowNodeValue(outputs[0], glowNode->getResult()); +} + +void PyTorchModelLoader::loadAdd(const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + assert(inputs.size() == 3); + assert(outputs.size() == 1); + + // TODO: extend this to allow non-constant scalars. + // Check scalar is 1, any other value is not supported. + const auto scalarHandle = getGlowConstantHandle(inputs[2]); + assert(scalarHandle.size() == 1); + assert(scalarHandle.raw(0) == 1); + + glow::NodeValue rhs = getGlowNodeValue(inputs[0]); + glow::NodeValue lhs = getGlowNodeValue(inputs[1]); + glow::AddNode *glowNode = f_->createAdd("add", rhs, lhs); + addGlowNodeValue(outputs[0], glowNode->getResult()); +} + +void PyTorchModelLoader::loadSub(const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + assert(inputs.size() == 3); + assert(outputs.size() == 1); + + // TODO: extend this to allow non-constant scalars. + // Check scalar is 1, any other value is not supported. 
+ const auto scalarHandle = getGlowConstantHandle(inputs[2]); + assert(scalarHandle.size() == 1); + assert(scalarHandle.raw(0) == 1); + + glow::NodeValue rhs = getGlowNodeValue(inputs[0]); + glow::NodeValue lhs = getGlowNodeValue(inputs[1]); + glow::SubNode *glowNode = f_->createSub("sub", rhs, lhs); + addGlowNodeValue(outputs[0], glowNode->getResult()); +} + +void PyTorchModelLoader::loadRelu(const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + glow::NodeValue input = getGlowNodeValue(inputs[0]); + glow::ReluNode *glowNode = f_->createRELU("relu", input); + addGlowNodeValue(outputs[0], glowNode->getResult()); +} + +void PyTorchModelLoader::loadConvolution(const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + assert(inputs.size() == 12); + assert(outputs.size() == 1); + + // Glow expects conv inputs to be in NHWC but PyTorch keeps them in NCHW so we + // tranpose them. + glow::NodeValue input = getGlowNodeValue(inputs[0]); + input = f_->createTranspose("conv_input_transposed", input, NCHW2NHWC); + glow::ShapeNHWC inputShape(input.dims()); + + // Glow expects conv weights to be in CRSK but PyTorch keeps them in CKRS + // so we tranpose them. C - output_depth, R - filter_height, S - + // filter_width, K - input_depth. + glow::NodeValue weights = getGlowNodeValue(inputs[1]); + weights = f_->createTranspose("conv_weights_transposed", weights, NCHW2NHWC); + glow::ShapeNHWC weightsShape(weights.dims()); + + // If a bias was provided then use it otherwise create a 0 bias. + glow::NodeValue bias; + if (hasGlowNodeValue(inputs[2])) { + bias = getGlowNodeValue(inputs[2]); + } else { + glow::Tensor biasT(glow::ElemKind::FloatTy, {weightsShape.n}); + biasT.zero(); + glow::Constant *biasConstant = + f_->getParent()->createConstant("conv_bias", std::move(biasT)); + bias = biasConstant->getOutput(); } - case at::aten::div: { - assert(inputs.size() == 2); - glow::NodeValue rhs = valueMap_.at(inputs[0]); - glow::NodeValue lhs = valueMap_.at(inputs[1]); - glowNode = f_->createDiv("div", rhs, lhs); - valueMap_[outputs[0]] = glowNode->getNthResult(glow::DivNode::ResultIdx); - break; + + glow::Handle stridesHandle = + getGlowConstantHandle(inputs[3]); + std::vector strides = + expandParamIfNeeded(stridesHandle, 2); + + glow::Handle padsHandle = getGlowConstantHandle(inputs[4]); + uint32_t pad = static_cast(contractParamIfNeeded(padsHandle)); + std::vector pads = {pad, pad, pad, pad}; + + glow::Handle dilationHandle = + getGlowConstantHandle(inputs[5]); + uint32_t dilation = + static_cast(contractParamIfNeeded(dilationHandle)); + + // Don't support transposed convolutions yet. 
+ glow::Handle transposedHandle = getGlowConstantHandle(inputs[6]); + assert(transposedHandle.size() == 1 && !transposedHandle.raw(0) && + "Transposed convolutions not supported"); + (void)transposedHandle; + + glow::Handle groupsHandle = + getGlowConstantHandle(inputs[8]); + assert(groupsHandle.size() == 1); + uint32_t groups = groupsHandle.raw(0); + + std::vector kernels = {static_cast(weightsShape.h), + static_cast(weightsShape.w)}; + + auto outSz = glow::calculateConvPoolOutputDims( + inputShape.h, inputShape.w, kernels, strides, pads, dilation); + std::array outDims = { + {input.dims()[0], outSz.first, outSz.second, weightsShape.n}}; + glow::TypeRef outTy = + f_->getParent()->uniqueType(glow::ElemKind::FloatTy, outDims); + + glow::ConvolutionNode *conv = + f_->createConv("conv", input, weights, bias, outTy, kernels, strides, + pads, groups, dilation); + glow::TransposeNode *output = f_->createTranspose( + "conv_input_transposed", conv->getResult(), NHWC2NCHW); + + addGlowNodeValue(outputs[0], output->getResult()); +} + +void PyTorchModelLoader::loadConstant(const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + assert(inputs.size() == 0); + assert(outputs.size() == 1); + + auto optionalIValue = torch::jit::toIValue(outputs[0]); + assert(optionalIValue.has_value() && "Constants should have IValue outputs"); + const torch::jit::IValue iVal = *optionalIValue; + + glow::Module *mod = f_->getParent(); + + glow::Tensor t; + + if (iVal.isInt()) { + int64_t val = iVal.toInt(); + t.reset(glow::ElemKind::Int32ITy, {1}); + t.getHandle().raw(0) = to32Bit(val); + } else if (iVal.isIntList()) { + const auto vals = iVal.toIntListRef(); + t.reset(glow::ElemKind::Int32ITy, {vals.size()}); + auto handle = t.getHandle(); + for (size_t i = 0; i < vals.size(); ++i) { + handle.raw(i) = to32Bit(vals[i]); + } + } else if (iVal.isDouble()) { + double val = iVal.toDouble(); + t.reset(glow::ElemKind::FloatTy, {1}); + t.getHandle().raw(0) = to32Bit(val); + } else if (iVal.isDoubleList()) { + const auto vals = iVal.toDoubleListRef(); + t.reset(glow::ElemKind::FloatTy, {vals.size()}); + auto handle = t.getHandle(); + for (size_t i = 0; i < vals.size(); ++i) { + handle.raw(i) = to32Bit(vals[i]); + } + } else if (iVal.isBool()) { + bool val = iVal.toBool(); + t.reset(glow::ElemKind::BoolTy, {1}); + t.getHandle().raw(0) = val; + } else if (iVal.isBoolList()) { + const auto &valsList = iVal.toBoolList(); + const auto &vals = c10::impl::toVector(valsList); + t.reset(glow::ElemKind::BoolTy, {vals.size()}); + auto handle = t.getHandle(); + for (size_t i = 0; i < vals.size(); ++i) { + handle.raw(i) = vals[i]; + } + } else if (iVal.isTensor()) { + assert(false && "not yet implemented tensor constants"); + } else if (iVal.isNone()) { + // Nothing to do for None + return; + } else { + assert(false && "Unsupported constant type"); } + + glow::Constant *glowConstant = mod->createConstant("constant", std::move(t)); + addGlowNodeValue(outputs[0], glowConstant->getOutput()); +} + +void PyTorchModelLoader::loadNode(const torch::jit::Node *ptNode) { + switch (ptNode->kind()) { + case at::prim::Constant: + return loadConstant(ptNode); + case at::aten::mul: + return loadMul(ptNode); + case at::aten::div: + return loadDiv(ptNode); + case at::aten::add: + return loadAdd(ptNode); + case at::aten::sub: + return loadSub(ptNode); + case at::aten::_convolution: + return loadConvolution(ptNode); + case at::aten::relu: + return loadRelu(ptNode); default: assert(false && "unrecognized node"); 
   }
 }
 
+// static
+bool PyTorchModelLoader::isNodeSupported(const torch::jit::Node *node) {
+  // NOTE: Even though importing at::prim::Constant nodes is supported, don't
+  // add it here so that only the Constant nodes that are strictly necessary
+  // for values in the subgraph will be imported.
+  switch (node->kind()) {
+  case at::aten::mul:
+  case at::aten::div:
+  case at::aten::add:
+  case at::aten::sub:
+  case at::aten::relu:
+  case at::aten::_convolution:
+    return true;
+  default:
+    return false;
+  }
+  return false;
+}
+
 void PyTorchModelLoader::load() {
   assert(f_ == nullptr && "This model loader has already been used.");
 
@@ -90,7 +403,7 @@ void PyTorchModelLoader::load() {
   }
 
   for (torch::jit::Value *output : subgraph_->outputs()) {
-    glow::NodeValue outputNodeValue = valueMap_.at(output);
+    glow::NodeValue outputNodeValue = getGlowNodeValue(output);
     auto *save = f_->createSave("save", outputNodeValue);
     outputPlaceholders_.push_back(save->getPlaceholder());
   }
@@ -100,15 +413,3 @@ PyTorchModelLoader::PyTorchModelLoader(glow::Module &mod,
                                        torch::jit::Graph *subgraph,
                                        at::ArrayRef<torch::jit::IValue> &inputs)
     : mod_(mod), subgraph_(subgraph), inputs_(inputs) {}
-
-// static
-bool PyTorchModelLoader::isNodeSupported(const torch::jit::Node *node) {
-  switch (node->kind()) {
-  case at::aten::mul:
-  case at::aten::div:
-    return true;
-  default:
-    return false;
-  }
-  return false;
-}
diff --git a/torch_glow/src/PyTorchModelLoader.h b/torch_glow/src/PyTorchModelLoader.h
index f73a85bb6d..22cbb3da60 100644
--- a/torch_glow/src/PyTorchModelLoader.h
+++ b/torch_glow/src/PyTorchModelLoader.h
@@ -22,16 +22,28 @@
 #include "glow/Graph/Graph.h"
 
+/// Loads PyTorch JIT IR subgraphs as a Glow Function.
 class PyTorchModelLoader {
+  /// Glow module in which the loaded Glow Function will be created.
   glow::Module &mod_;
+  /// The PyTorch subgraph being loaded.
   torch::jit::Graph *subgraph_;
+
+  /// The reference PyTorch inputs used for loading. This is required for shape
+  /// information.
   at::ArrayRef<torch::jit::IValue> &inputs_;
+  /// The Glow function that will be created.
   glow::Function *f_ = nullptr;
+
+  /// Glow Placeholders corresponding to stack inputs from PyTorch.
   std::vector<glow::Placeholder *> inputPlaceholders_;
+
+  /// Glow Placeholders corresponding to stack outputs back to PyTorch.
   std::vector<glow::Placeholder *> outputPlaceholders_;
 
+  /// Mapping from PyTorch Values to Glow NodeValues created during loading.
   std::unordered_map<const torch::jit::Value *, glow::NodeValue> valueMap_;
 
 public:
@@ -47,21 +59,74 @@ class PyTorchModelLoader {
   void load();
 
   /// Returns whether or not a PyTorch node is supported.
+  /// NOTE: For now this is just an enumeration of all types of PyTorch nodes
+  /// that the loader knows about but doesn't really guarantee that loading
+  /// will succeed because determining this requires more information such as
+  /// shape info that isn't yet available when this is run.
   static bool isNodeSupported(const torch::jit::Node *node);
 
+  /// Get the Glow function that this loader has created.
   glow::Function *getFunction() { return f_; }
 
+  /// \returns the Glow input placeholders for the loaded function in the order
+  /// that is expected by the stack of PyTorch inputs when running the function.
   const std::vector<glow::Placeholder *> &getInputPlaceholders() {
     return inputPlaceholders_;
   }
 
+  /// \returns the Glow output placeholders for the loaded function in the order
+  /// that is expected for the stack of PyTorch outputs when running the function.
   const std::vector<glow::Placeholder *> &getOutputPlaceholders() {
     return outputPlaceholders_;
   }
 
 private:
-  glow::Placeholder *loadValue(const torch::jit::Value *val);
+  /// Find the Glow NodeValue that maps to a given PyTorch value \p value.
+  glow::NodeValue getGlowNodeValue(const torch::jit::Value *value) const;
+
+  /// Returns true if a Glow NodeValue has been created for a given PyTorch
+  /// Value \p value.
+  bool hasGlowNodeValue(const torch::jit::Value *value) const;
+
+  /// Add a new mapping from the PyTorch Value \p value to the Glow NodeValue
+  /// \p nodeValue.
+  void addGlowNodeValue(const torch::jit::Value *value,
+                        glow::NodeValue nodeValue);
+
+  /// Given a PyTorch Value \p value, returns a handle to the tensor backing
+  /// the glow Constant that is mapped to this PyTorch Value. This requires that
+  /// a mapping from the given Value to a Glow NodeValue has already been
+  /// created and that the NodeValue is the output of a Glow ConstantNode.
+  template <typename T>
+  glow::Handle<T> getGlowConstantHandle(const torch::jit::Value *value) const;
+
+  /// Creates and \returns a new Glow Placeholder corresponding to the given
+  /// PyTorch Value \p value.
+  glow::Placeholder *loadValue(const torch::jit::Value *value);
+
+  /// Load a given PyTorch Node \p ptNode.
   void loadNode(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch Constant node as a Glow Constant.
+  void loadConstant(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch mul node.
+  void loadMul(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch div node.
+  void loadDiv(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch add node.
+  void loadAdd(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch sub node.
+  void loadSub(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch relu node.
+  void loadRelu(const torch::jit::Node *ptNode);
+
+  /// Load a PyTorch _convolution node.
+  void loadConvolution(const torch::jit::Node *ptNode);
 };
 
 #endif // GLOW_IMPORTER_PYTORCH_PYTORCHMODELLOADER_H
diff --git a/torch_glow/src/binding.cpp b/torch_glow/src/binding.cpp
index 6243f129c4..f2254a4280 100644
--- a/torch_glow/src/binding.cpp
+++ b/torch_glow/src/binding.cpp
@@ -17,6 +17,7 @@
 #include
 #include
+#include
 #include
 #include
 
@@ -27,17 +28,24 @@
 namespace py = pybind11;
 
+/// The PyTorch symbol used to identify the Node that contains PyTorch subgraphs
+/// that are compiled for running on Glow.
 static auto glowSymbol = at::Symbol::fromQualString("glow::CompilationGroup");
+
 /// Whether or not run the custom pass that fuses jit nodes into a glow node.
 static bool fusionPassEnabled = false;
 
+/// Manages a CachingGraphRunner singleton.
 CachingGraphRunner *getGraphRunner() {
   static auto runner_ =
       std::unique_ptr<CachingGraphRunner>(new CachingGraphRunner());
   return runner_.get();
 }
 
+/// Enable compiling PyTorch subgraphs to Glow Functions.
 void enableFusionPass() { fusionPassEnabled = true; }
+
+/// Disable compiling PyTorch subgraphs to Glow Functions.
 void disableFusionPass() { fusionPassEnabled = false; }
 
 /// Register the glow::CompilationGroup operator.
@@ -48,7 +56,6 @@ void registerGlowOp() { torch::jit::RegisterOperators op( {torch::jit::Operator(glowSymbol, [](const torch::jit::Node *node) { - getGraphRunner()->addGraph(node); return [node](torch::jit::Stack &stack) { getGraphRunner()->runGraph(node, stack); return 0; @@ -62,11 +69,13 @@ void registerGlowOp() { void registerPass() { torch::jit::RegisterPass pass([](std::shared_ptr &g) { if (fusionPassEnabled) { - CustomFuseGraph(g, PyTorchModelLoader::isNodeSupported, glowSymbol); + torch::jit::CustomFuseGraph(g, PyTorchModelLoader::isNodeSupported, + glowSymbol); } }); } +/// The torch_glow pybind11 module. PYBIND11_MODULE(_torch_glow, m) { registerGlowOp(); registerPass(); diff --git a/torch_glow/tests/__init__.py b/torch_glow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torch_glow/tests/basic_test.py b/torch_glow/tests/basic_test.py deleted file mode 100644 index df5c8bd981..0000000000 --- a/torch_glow/tests/basic_test.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -import torch_glow - -@torch.jit.script -def foo(a, b): - c = a.mul(b) - a = c.mul(c) - a = c.mul(a) - d = c.div(a) - return d - -torch_glow.enableFusionPass() - -@torch.jit.script -def foo_glow(a, b): - return foo(a, b) - -def test_foo(): - x = torch.randn(4) - y = torch.randn(4) - - jit_res = foo(x, y) - jit_glow_res = foo_glow(x, y) - - assert torch.allclose(jit_res, jit_glow_res) - \ No newline at end of file diff --git a/torch_glow/tests/nodes/add_test.py b/torch_glow/tests/nodes/add_test.py new file mode 100644 index 0000000000..c67c938fa4 --- /dev/null +++ b/torch_glow/tests/nodes/add_test.py @@ -0,0 +1,16 @@ +import torch +import torch_glow + +from tests.utils import jitVsGlow + +# Basic test of the PyTorch add Node on Glow. +def test_add_basic(): + + def add_basic(a, b): + c = a.add(b) + return c.add(c) + + x = torch.randn(4) + y = torch.randn(4) + + jitVsGlow(add_basic, x, y) \ No newline at end of file diff --git a/torch_glow/tests/nodes/conv2d_test.py b/torch_glow/tests/nodes/conv2d_test.py new file mode 100644 index 0000000000..790e0407f7 --- /dev/null +++ b/torch_glow/tests/nodes/conv2d_test.py @@ -0,0 +1,54 @@ +import torch +import torch.nn.functional as F +import torch_glow +from collections import namedtuple + +from tests.utils import jitVsGlow + +# Basic test of the PyTorch conv2d Node on Glow. +def test_conv2d_basic(): + + def conv2d_basic(inputs, filters): + conv = F.conv2d(inputs, filters, padding=1) + return F.relu(conv) + + inputs = torch.randn(1, 4, 5, 5) + filters = torch.randn(8, 4, 3, 3) + + jitVsGlow(conv2d_basic, inputs, filters) + +# Test of the PyTorch conv2d Node with a provided bias tensor. +def test_conv2d_with_bias(): + + def conv2d_with_bias(inputs, filters, bias): + conv = F.conv2d(inputs, filters, bias) + return F.relu(conv) + + inputs = torch.randn(1, 4, 5, 5) + filters = torch.randn(8, 4, 3, 3) + bias = torch.randn(8) + + jitVsGlow(conv2d_with_bias, inputs, filters, bias) + +# Test of the PyTorch conv2d Node sweeping through various parameters of the +# Node to test that they work correctly. 
+def test_conv2d_param_sweep():
+    hwOpts = [3, 4]
+    padOpts = [0, 1]
+    groupsOpts = [1, 2]
+    dilationOpts = [1, 2]
+    strideOpts = [1, 2]
+
+    Setting = namedtuple('Setting', ['h', 'w', 'p', 'g', 'd', 's',])
+
+    settings = [Setting(h=h, w=w, p=p, g=g, d=d, s=s) for h in hwOpts for w in hwOpts for p in padOpts for g in groupsOpts for d in dilationOpts for s in strideOpts]
+
+    for setting in settings:
+        def conv2d_param_sweep(inputs, filters):
+            conv = F.conv2d(inputs, filters, padding=setting.p, groups=setting.g)
+            return F.relu(conv)
+
+        inputs = torch.randn(2, 4, setting.h, setting.w)
+        filters = torch.randn(8, 4 // setting.g, 3, 3)
+
+        jitVsGlow(conv2d_param_sweep, inputs, filters)
diff --git a/torch_glow/tests/nodes/div_test.py b/torch_glow/tests/nodes/div_test.py
new file mode 100644
index 0000000000..95666b541b
--- /dev/null
+++ b/torch_glow/tests/nodes/div_test.py
@@ -0,0 +1,16 @@
+import torch
+import torch_glow
+
+from tests.utils import jitVsGlow
+
+# Basic test of the PyTorch div Node on Glow.
+def test_div_basic():
+
+    def div_basic(a, b):
+        c = a.div(b)
+        return c.div(c)
+
+    x = torch.randn(4)
+    y = torch.randn(4)
+
+    jitVsGlow(div_basic, x, y)
\ No newline at end of file
diff --git a/torch_glow/tests/nodes/mul_test.py b/torch_glow/tests/nodes/mul_test.py
new file mode 100644
index 0000000000..5b1ff80d72
--- /dev/null
+++ b/torch_glow/tests/nodes/mul_test.py
@@ -0,0 +1,16 @@
+import torch
+import torch_glow
+
+from tests.utils import jitVsGlow
+
+# Basic test of the PyTorch mul Node on Glow.
+def test_mul_basic():
+
+    def mul_basic(a, b):
+        c = a.mul(b)
+        return c.mul(c)
+
+    x = torch.randn(4)
+    y = torch.randn(4)
+
+    jitVsGlow(mul_basic, x, y)
\ No newline at end of file
diff --git a/torch_glow/tests/nodes/relu_test.py b/torch_glow/tests/nodes/relu_test.py
new file mode 100644
index 0000000000..b30c18a6ad
--- /dev/null
+++ b/torch_glow/tests/nodes/relu_test.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn.functional as F
+import torch_glow
+
+from tests.utils import jitVsGlow
+
+# Basic test of the PyTorch relu Node on Glow.
+def test_relu_basic():
+
+    def relu_basic(a):
+        b = F.relu(a)
+        return F.relu(b)
+
+    x = torch.randn(4)
+    # make sure we have at least one negative
+    x[0] = -2.0
+
+    jitVsGlow(relu_basic, x)
\ No newline at end of file
diff --git a/torch_glow/tests/nodes/sub_test.py b/torch_glow/tests/nodes/sub_test.py
new file mode 100644
index 0000000000..3b4308ee39
--- /dev/null
+++ b/torch_glow/tests/nodes/sub_test.py
@@ -0,0 +1,16 @@
+import torch
+import torch_glow
+
+from tests.utils import jitVsGlow
+
+# Basic test of the PyTorch sub Node on Glow.
+def test_sub_basic():
+
+    def sub_basic(a, b):
+        c = a.sub(b)
+        return c.sub(c)
+
+    x = torch.randn(4)
+    y = torch.randn(4)
+
+    jitVsGlow(sub_basic, x, y)
\ No newline at end of file
diff --git a/torch_glow/tests/utils.py b/torch_glow/tests/utils.py
new file mode 100644
index 0000000000..27898deea4
--- /dev/null
+++ b/torch_glow/tests/utils.py
@@ -0,0 +1,30 @@
+import torch
+import torch_glow
+
+GLOW_NODE_NAME = "glow::CompilationGroup"
+
+# Runs the given inputs \p *inputs on \p f both with and without lowering \p f
+# to Glow and compares the results.
+def jitVsGlow(f, *inputs):
+    torch_glow.disableFusionPass()
+    torch_trace = torch.jit.trace(f, inputs)
+    torch_res = torch_trace(*inputs)
+
+    torch_glow.enableFusionPass()
+    glow_trace = torch.jit.trace(f, inputs)
+    glow_res = glow_trace(*inputs)
+
+    # check that there are no Glow nodes in the torch graph
+    torch_graph = torch_trace.graph_for(*inputs)
+    print("torch_graph,", torch_graph)
+    num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
+    assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(num_glow_nodes)
+
+    # check that there is exactly 1 Glow node in the glow graph
+    glow_graph = glow_trace.graph_for(*inputs)
+    print("glow_graph,", glow_graph)
+    num_glow_nodes = len(glow_graph.findAllNodes(GLOW_NODE_NAME))
+    assert num_glow_nodes == 1, "Expected exactly 1 Glow node, found {}".format(num_glow_nodes)
+
+
+    assert torch.allclose(torch_res, glow_res, atol=1e-6)
\ No newline at end of file
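
A note on the `setup.py` changes above: the script now parses only the Glow-specific flags itself and pushes everything it does not recognize back into `sys.argv`, so setuptools still sees its own commands (`test`, `develop`, `install`). A minimal standalone sketch of that pattern, using the same flag names as the patch but otherwise purely illustrative:

```python
import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--run_cmake", action="store_true", help="Force cmake to run again")
parser.add_argument("--release", action="store_true", help="Compile in Release mode (default is Debug)")
parser.add_argument("--cmake_prefix_path", type=str, help="Populates -DCMAKE_PREFIX_PATH")

# parse_known_args() returns (parsed, leftover); the leftover arguments are the
# setuptools commands and options, which are handed back to sys.argv so that
# setup() can parse them as usual.
args, leftover = parser.parse_known_args()
sys.argv = [sys.argv[0]] + leftover
```

With this in place, an invocation like `python setup.py develop --release --cmake_prefix_path=<llvm install>` should leave only `develop` for setuptools to handle.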
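For orientation, here is a minimal sketch of how the pieces in this patch are meant to fit together from Python (it assumes the extension has been built and is importable as `torch_glow`): `enableFusionPass()` turns on the pass that groups supported JIT nodes into a `glow::CompilationGroup` node, and the operator registered for that symbol hands the subgraph to the `CachingGraphRunner`, which loads and compiles it with `PyTorchModelLoader` the first time it runs.

```python
import torch
import torch_glow

def simple_fn(a, b):
    # mul, add and relu are all listed in PyTorchModelLoader::isNodeSupported,
    # so the whole body should be fused into one glow::CompilationGroup node.
    return torch.relu(a.mul(b).add(a))

torch_glow.enableFusionPass()

x = torch.randn(4)
y = torch.randn(4)

glow_trace = torch.jit.trace(simple_fn, (x, y))
out = glow_trace(x, y)  # the first call triggers loading and compilation in Glow

print(glow_trace.graph_for(x, y))  # should contain a glow::CompilationGroup node
```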
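The new per-operator tests all follow the same pattern: define a small function that exercises the operator, then pass it to `tests.utils.jitVsGlow`, which runs it once without and once with the fusion pass and compares the two results. As a sketch, a hypothetical test combining two already-supported ops would look like this (the names are made up for illustration):

```python
import torch
import torch_glow

from tests.utils import jitVsGlow

# Hypothetical test exercising sub and div together in one fused subgraph.
def test_sub_div_chain():

    def sub_div_chain(a, b):
        c = a.sub(b)
        return c.div(a)

    x = torch.randn(4)
    y = torch.randn(4)

    jitVsGlow(sub_div_chain, x, y)
```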
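One subtlety in `loadAdd` and `loadSub`: `aten::add` and `aten::sub` carry a third scalar input (`alpha`), and the loader currently asserts that it is the constant 1. `isNodeSupported` only looks at the node kind, so as the code stands a call with a different `alpha` would still be fused and would then hit that assert at load time. Purely illustrative:

```python
import torch

def loadable_add(a, b):
    # Tensor.add defaults to alpha=1, which is the only value loadAdd accepts.
    return a.add(b)

def not_yet_loadable_add(a, b):
    # alpha != 1 is rejected by an assert in PyTorchModelLoader::loadAdd.
    return a.add(b, alpha=2)
```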
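`loadConvolution` transposes the input from NCHW to the NHWC layout Glow expects, transposes the filters accordingly, synthesizes a zero bias when none is provided, and reads stride, padding, dilation and groups from the constant inputs of `aten::_convolution`; transposed convolutions are rejected. A sketch of one more `jitVsGlow`-style case that stays within those constraints, here a strided convolution without a bias (test name and shapes are illustrative, not part of the patch):

```python
import torch
import torch.nn.functional as F
import torch_glow

from tests.utils import jitVsGlow

# Hypothetical strided conv2d test; inputs stay in PyTorch's usual NCHW layout
# and the loader performs the NCHW <-> NHWC transposes internally.
def test_conv2d_stride_two():

    def conv2d_stride_two(inputs, filters):
        conv = F.conv2d(inputs, filters, stride=2, padding=1)
        return F.relu(conv)

    inputs = torch.randn(1, 4, 8, 8)
    filters = torch.randn(8, 4, 3, 3)

    jitVsGlow(conv2d_stride_two, inputs, filters)
```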