Skip to content

Commit

Permalink
[RFC] Modularize functions of parsing bytecode (#61862)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #61862

Modularize the functions that parse bytecode tables so that they can be reused as needed in situations other than the mobile lite interpreter.
* The decoupled functions are re-used by the current lite interpreter loader.
* The bytecode can be serialized/deserialized from other formats.
* The decoupled functions have minimum dependencies on other PyTorch components.

Next:
Build a driver binary that includes the parser and interpreter, with only the necessary dependencies on other PyTorch components.
ghstack-source-id: 137867287

Test Plan:
As an example, a simple bytecode is parsed to a mobile function, and directly run in the added unit test, `RunTimeTest:ParseBytecode`. It contains basic control flow (if, else) and basic data orchestration (list construction).
CI

Reviewed By: larryliu0820

Differential Revision: D29798382

Pulled By: iseeyuan

fbshipit-source-id: 1c173a5f5d37097e3a97baec3f3e48e1eea1400f
  • Loading branch information
iseeyuan authored and facebook-github-bot committed Sep 12, 2021
1 parent dd2d48d commit 30a7c76
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 171 deletions.
3 changes: 2 additions & 1 deletion caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/import.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/import_data.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp
Expand Down
70 changes: 70 additions & 0 deletions test/cpp/jit/test_lite_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/runtime_compatibility.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
Expand Down Expand Up @@ -886,6 +887,75 @@ TEST(LiteInterpreterTest, DefaultArgsConv) {
AT_ASSERT(output.equal(outputref));
}

// Verifies that a mobile::Function can be assembled directly from raw bytecode
// tables (instructions, types, register size) via the decoupled parse* helpers,
// without going through TorchScript serialization, and then executed.
TEST(RunTimeTest, ParseBytecode) {
// A simple example to show a simple bytecode that can be used independent of
// PyTorch TorchScript serialization (unpickler, etc) and operator library.
// It has basic control flow (if, else) and basic data orchestration (list
// construction). The original PyTorch program:

// class Module(torch.nn.Module):
//
// def __init__(self):
// super().__init__()
//
// def forward(self, x: int, h: int, xfirst: bool):
// if xfirst:
// return [x, h]
// else:
// return [h, x]

// 1. Prepare for the bytecode. In reality it can be from a customized
// deserializer.
// Each tuple is (op_name, X, N); the meaning of X/N depends on the opcode
// (e.g. register index + flag, or jump offset) — see the OpCode definitions
// in torch/csrc/jit/runtime/instruction.h for the exact semantics.
std::vector<IValue> instructions{
to_tuple({"STOREN", 1, 4}),
to_tuple({"DROPR", 1, 0}),
to_tuple({"MOVE", 4, 0}),
// JF 5: jump forward 5 instructions if the popped bool (xfirst) is false,
// i.e. skip to the "else" branch below.
to_tuple({"JF", 5, 0}),
to_tuple({"LOAD", 2, 0}),
to_tuple({"LOAD", 3, 0}),
to_tuple({"LIST_CONSTRUCT", 0, 2}),
// Unconditional jump over the "else" branch to the common epilogue.
to_tuple({"JMP", 4, 0}),
to_tuple({"LOAD", 3, 0}),
to_tuple({"LOAD", 2, 0}),
to_tuple({"LIST_CONSTRUCT", 1, 2}),
to_tuple({"STORE", 5, 0}),
to_tuple({"DROPR", 3, 0}),
to_tuple({"DROPR", 2, 0}),
to_tuple({"MOVE", 5, 0}),
to_tuple({"RET", 0, 0}),
};
std::vector<IValue> operators; // empty for this example
std::vector<IValue> constants; // empty for this example

// Type strings for the two LIST_CONSTRUCT results (one per branch).
std::vector<IValue> types{"List[int]", "List[int]"};
// 2. Parse the function
std::string function_name("test_function");
auto function = std::unique_ptr<mobile::Function>(
new mobile::Function(c10::QualifiedName(function_name)));
std::vector<IValue> debug_handles_m_tuple;
parseInstructions(
function_name, instructions, debug_handles_m_tuple, function.get());
parseTypes(types, function.get());
// Register file size: 4 arguments plus 1 temporary (register 5 above).
const size_t rsize = 5;
parseRegisterSize(rsize, function.get());

// 3. Prepare for inputs and run the function
// Note that the first input is reserved for Module object.
// Since this is a function test and Module object is not required,
// a dummy IValue (0) is added here.
// run() executes in place on the stack; the assertions below show the
// return value is left in slot 0 after execution.
std::vector<IValue> inputs{0, 1, 2, true};
function->run(inputs);
auto output = inputs[0].toList();
ASSERT_EQ(output[0], 1);
ASSERT_EQ(output[1], 2);

// xfirst == false takes the "else" branch and reverses the order.
std::vector<IValue> inputs1{0, 1, 2, false};
function->run(inputs1);
auto output1 = inputs1[0].toList();
ASSERT_EQ(output1[0], 2);
ASSERT_EQ(output1[1], 1);
}

namespace {
void testLiteModuleCompareResultTensors(
Module& m,
Expand Down
2 changes: 2 additions & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ torch_mobile_core = [
"torch/csrc/jit/mobile/model_compatibility.cpp",
"torch/csrc/jit/mobile/module.cpp",
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/runtime/register_prim_ops.cpp",
"torch/csrc/jit/runtime/register_special_ops.cpp",
]
Expand Down Expand Up @@ -474,6 +475,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/jit/mobile/model_compatibility.cpp",
"torch/csrc/jit/mobile/module.cpp",
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/train/export_data.cpp",
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
"torch/csrc/jit/mobile/train/random.cpp",
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/mobile/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class_detail.h>

namespace torch {
namespace jit {
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/mobile/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class Function {
using OperatorCacheType =
std::unordered_map<c10::OperatorName, OperatorFunctionWithSchema>;

Function(c10::QualifiedName name);
bool run(Stack& stack) const;
TORCH_API Function(c10::QualifiedName name);
TORCH_API bool run(Stack& stack) const;
c10::IValue operator()(Stack& stack) const;
const std::string& name() const;
TORCH_API const c10::QualifiedName& qualname() const;
Expand Down
Loading

0 comments on commit 30a7c76

Please sign in to comment.