[PyTorchEdge] backport v8 to v7 to support promoted ops as instruction (#71662)

Summary:
Pull Request resolved: #71662

Backport bytecode v8 to v7 to support promoted ops as instructions.

Adds a flag so that promoted ops are exported as instructions for bytecode v8, and as operators for v7 and below.
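For context, here is how a caller would exercise this path end to end — a minimal sketch, assuming the stream overload of `_backport_for_mobile` declared in `torch/csrc/jit/mobile/compatibility/backport.h` (the same entry point the test below uses):

```cpp
#include <torch/csrc/jit/mobile/compatibility/backport.h>

#include <sstream>

// Sketch: rewrite a freshly exported (v8) module as bytecode v7 so that
// older mobile runtimes can load it. Internally this dispatches to the
// per-version steps registered with the BackportManager, including the
// backport_v8_to_v7 step added in this diff.
bool backportToV7(std::stringstream& v8_model, std::stringstream& v7_model) {
  return torch::jit::_backport_for_mobile(v8_model, v7_model, /*to_version=*/7);
}
```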

Test Plan:
```
buck test caffe2/test/cpp/jit:jit -- LiteInterpreterTest.BackPortByteCodeModelAllVersions

Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
    ✓ ListingSuccess: caffe2/test/cpp/jit:jit : 461 tests discovered (15.693)
    ✓ Pass: caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions (2.712)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
```

Differential Revision: D33719098

fbshipit-source-id: 4f994c7e196a5838ea9376067a0233c0bf52de9c
pavithranrao authored and facebook-github-bot committed Jan 27, 2022
1 parent fdec945 commit f8ff95f
Showing 11 changed files with 146 additions and 33 deletions.
31 changes: 18 additions & 13 deletions caffe2/serialize/versions.h
@@ -91,22 +91,27 @@ constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
// 0x2L: (Comment missing)
// 0x3L: (Comment missing)
// 0x4L: (update) Added schema to function tuple. Forward-compatible change.
-// 0x5L: (update) Update bytecode is sharing constant tensor files from torchscript, and only serialize
-// extra tensors that are not in the torchscript constant table. Also update tensor storage schema adapting
-// to the unify format, the root key of tensor storage is updated from {index} to
-// {the_pointer_value_the_tensor.storage}, for example: `140245072983168.storage`
-// Forward-compatibility change.
-// 0x6L: Implicit opereator versioning using number of specified argument.
-// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845
-// for details.
-// 0x7L: Enable support for operators with default arguments plus out arguments.
-constexpr uint64_t kProducedBytecodeVersion = 0x7L;
+// 0x5L: (update) Update bytecode is sharing constant tensor files from
+//       torchscript, and only serialize extra tensors that are not in the
+//       torchscript constant table. Also update tensor storage schema,
+//       adapting to the unified format: the root key of tensor storage is
+//       updated from {index} to {the_pointer_value_the_tensor.storage},
+//       for example `140245072983168.storage`. Forward-compatibility change.
+// 0x6L: Implicit operator versioning using the number of specified arguments.
+//       Refer to the summary of https://github.com/pytorch/pytorch/pull/56845
+//       for details.
+// 0x7L: Enable support for operators with default arguments plus out arguments.
+// 0x8L: Emit promoted operators as instructions.
+constexpr uint64_t kProducedBytecodeVersion = 0x8L;

static_assert(
kProducedBytecodeVersion >= kProducedFileFormatVersion,
"kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");

// Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion
// for limited backward/forward compatibility support of bytecode. If
-// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion (in loader),
-// we should support this model_version. For example, we provide a wrapper to
-// handle an updated operator.
+// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion
+// (in loader), we should support this model_version. For example, we provide a
+// wrapper to handle an updated operator.
constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;

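To make the supported-range comment above concrete, a minimal sketch of the loader-side gate it describes — `validate_bytecode_version` is a hypothetical helper; the real check lives in the mobile loader:

```cpp
#include <c10/util/Exception.h>
#include <caffe2/serialize/versions.h>

// Hypothetical helper: a model is loadable iff its bytecode version falls
// within [kMinSupportedBytecodeVersion, kMaxSupportedBytecodeVersion].
void validate_bytecode_version(uint64_t model_version) {
  TORCH_CHECK(
      model_version >= caffe2::serialize::kMinSupportedBytecodeVersion &&
          model_version <= caffe2::serialize::kMaxSupportedBytecodeVersion,
      "Unsupported bytecode version: ",
      model_version);
}
```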
63 changes: 52 additions & 11 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -571,19 +571,34 @@ namespace {

void compareModelOutput(
c10::ArrayRef<IValue> actual_result_list,
-    const std::vector<Tensor>& expect_result_list) {
+    const std::vector<IValue>& expect_result_list) {
  AT_ASSERT(actual_result_list.size() == expect_result_list.size());
-  AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
-  AT_ASSERT(
-      actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
-  AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
-  AT_ASSERT(actual_result_list[3].toTensor().equal(expect_result_list[3]));
+  AT_ASSERT(
+      actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
+  AT_ASSERT(
+      actual_result_list[1].toTensor().dim() ==
+      expect_result_list[1].toTensor().dim());
+  AT_ASSERT(
+      actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
+  AT_ASSERT(
+      actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
+  ASSERT_EQ(
+      actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
+  ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
+  ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
+  ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
+  AT_ASSERT(
+      actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
+  ASSERT_EQ(
+      actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
+  ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
+  ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
}

void runAndCheckTorchScriptModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
-    const std::vector<Tensor>& expect_result_list,
+    const std::vector<IValue>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);
@@ -600,7 +615,7 @@ void runAndCheckTorchScriptModel(
void runAndCheckBytecodeModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
-    const std::vector<Tensor>& expect_result_list,
+    const std::vector<IValue>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);
@@ -618,7 +633,7 @@ void runAndCheckBytecodeModel(
void backportAllVersionCheck(
std::stringstream& test_model_file_stream,
std::vector<IValue>& input_data,
-    std::vector<Tensor>& expect_result_list,
+    std::vector<IValue>& expect_result_list,
const int64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);
@@ -668,6 +683,9 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
module.register_parameter("bias", torch::ones({20}), false);
module.define(R"(
def fn(self, x:float=1.0):
return x
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
@@ -677,21 +695,44 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
      x = 2 * torch.ones(1)
      h = torch.ones(1)
      torch.add(x, h, out=x)
-      return (x1, x2, x3, x)
+      device = torch.ones(1, 1).cpu().device.type
+      is_cuda = x1.is_cuda
+      bool_val = True
+      check_is = [] is None
+      check_is_not = [1] is not None
+      check_not = not bool_val
+      num_to_tensor = torch.tensor([self.fn()])
+      d = {"a": "abc"}
+      check_dict_index = d["a"]
+      check_dim = x1.dim()
+      return (
+        x1, x2, x3, x, device, is_cuda, check_is,
+        check_is_not, num_to_tensor, check_dict_index,
+        check_dim, check_not
+      )
  )");

torch::jit::Module module_freeze = freeze(module);

std::stringstream input_model_stream;
module_freeze._save_for_mobile(input_model_stream);
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
-  std::vector<Tensor> expect_result_list;
+  std::vector<IValue> expect_result_list;
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
expect_result_list.emplace_back(
at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
expect_result_list.emplace_back(3 * at::ones({1}));
// "cpu" False, False, True, tensor(1), "abc", 2, False)
expect_result_list.emplace_back(c10::IValue("cpu"));
expect_result_list.emplace_back(c10::IValue(false));
expect_result_list.emplace_back(c10::IValue(false));
expect_result_list.emplace_back(c10::IValue(true));
expect_result_list.emplace_back(c10::IValue(at::ones({1})));
expect_result_list.emplace_back(c10::IValue("abc"));
expect_result_list.emplace_back(c10::IValue(2));
expect_result_list.emplace_back(c10::IValue(false));

backportAllVersionCheck(
input_model_stream,
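The helpers above compose into a per-version round trip. A condensed sketch of one backport step the test performs, using the same `_backport_for_mobile`, `_get_model_bytecode_version`, and `_load_for_mobile` calls as the test (assertion style simplified):

```cpp
#include <ATen/ATen.h>
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/import.h>

#include <sstream>

// Sketch: back-port a saved module one version down, confirm the written
// bytecode version, then check that the result still loads and runs under
// the mobile interpreter.
void checkOneBackportStep(std::stringstream& saved_module, int64_t to_version) {
  std::stringstream backported;
  AT_ASSERT(
      torch::jit::_backport_for_mobile(saved_module, backported, to_version));
  AT_ASSERT(torch::jit::_get_model_bytecode_version(backported) == to_version);
  auto mobile_module = torch::jit::_load_for_mobile(backported);
  mobile_module.forward({at::ones({1, 1, 28, 28})});
}
```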
4 changes: 2 additions & 2 deletions test/test_mobile_optimizer.py
@@ -151,7 +151,7 @@ def forward(self, x):
bn_scripted_module = torch.jit.script(bn_test_module)
bn_scripted_module.eval()

-        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
+        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward(bn_scripted_module._c).graph))

@@ -252,7 +252,7 @@ def foo(self, x):
bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
bn_no_forward_scripted_module.eval()

-        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
+        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(bn_no_forward_scripted_module.foo.graph)

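The drop from 14 to 11 is the observable effect of this diff: ops that are now emitted as instructions no longer show up as root operators. A hedged C++ equivalent of the check above, assuming `torch::jit::export_opnames` from `torch/csrc/jit/serialization/export.h`:

```cpp
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/serialization/export.h>

// Sketch: with v8 emission, promoted ops become instructions and are not
// counted among the module's exported (root) operators.
size_t countRootOps(const torch::jit::Module& module) {
  return torch::jit::export_opnames(module).size();
}
```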
41 changes: 39 additions & 2 deletions torch/csrc/jit/mobile/compatibility/backport_manager.cpp
@@ -27,6 +27,7 @@ constexpr int64_t kBytecodeVersionV4 = 0x4L;
constexpr int64_t kBytecodeVersionV5 = 0x5L;
constexpr int64_t kBytecodeVersionV6 = 0x6L;
constexpr int64_t kBytecodeVersionV7 = 0x7L;
+constexpr int64_t kBytecodeVersionV8 = 0x8L;
} // namespace

/********************** Utility Functions **********************/
@@ -434,7 +435,8 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
{
BytecodeEmitModeGuard argNumGuard(
true /*emit_default_input_instructions*/,
-        false /*enable_defaults_args_with_out_args*/);
+        false /*enable_defaults_args_with_out_args*/,
+        false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
@@ -501,7 +503,8 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
{
BytecodeEmitModeGuard argNumGuard(
false /*emit_default_input_instructions*/,
-        false /*enable_defaults_args_with_out_args*/);
+        false /*enable_defaults_args_with_out_args*/,
+        false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
@@ -512,6 +515,39 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
return output_model_stream;
}

+std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
+  std::shared_ptr<IStreamAdapter> rai =
+      std::make_shared<IStreamAdapter>(&input_model_stream);
+  auto reader = std::make_shared<PyTorchStreamReader>(rai);
+  // extra_files are kept
+  auto records = reader->getAllRecords();
+  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");
+  ExtraFilesMap extra_files;
+  for (const auto& record : records) {
+    std::size_t found = record.find_last_of("/\\");
+    auto path = record.substr(0, found);
+    if ("extra" == path) {
+      extra_files.emplace(record.substr(found + 1), "");
+    }
+  }
+  Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files);
+  std::stringstream intermediate_model_stream;
+  {
+    BytecodeEmitModeGuard argNumGuard(
+        false /*emit_default_input_instructions*/,
+        true /*enable_defaults_args_with_out_args*/,
+        false /*enable_emit_promoted_ops*/);
+    torch_script._save_for_mobile(
+        intermediate_model_stream, extra_files, hasBytecodeDebug);
+  }

+  // Update the bytecode version (from 8 to 7)
+  std::stringstream output_model_stream =
+      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7);

+  return output_model_stream;
+}

} // namespace

/********************** BackportManager **********************/
@@ -528,6 +564,7 @@ BackportManager::BackportManager() {
registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
+  registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7);
}

std::unordered_map<
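With v8 registered, multi-version backports chain these single-step functions. A sketch of the dispatch shape — `backport_function_for` is a hypothetical lookup into the registry populated above; the real loop lives in `BackportManager::backport`:

```cpp
#include <cstdint>
#include <functional>
#include <sstream>

using BackportFn = std::function<std::stringstream(std::stringstream&)>;

// Hypothetical lookup into the registered backport functions.
BackportFn backport_function_for(int64_t from_version);

// Sketch: step the model down one bytecode version at a time, e.g.
// v8 -> v7 -> v6 when backporting from 8 to 6.
std::stringstream backportChain(
    std::stringstream model, int64_t from_version, int64_t to_version) {
  for (int64_t v = from_version; v > to_version; --v) {
    model = backport_function_for(v)(model);
  }
  return model;
}
```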
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/interpreter.cpp
@@ -1059,12 +1059,14 @@ MobileCode::MobileCode(
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
+    bool emit_promoted_ops,
size_t remaining_bailout_depth)
: Code(new interpreter::MobileCodeImpl(
graph,
std::move(function_name),
emit_default_input_instructions,
support_default_args_before_out,
+          emit_promoted_ops,
remaining_bailout_depth)) {}

MobileCode::~MobileCode() = default;
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/interpreter.h
@@ -88,6 +88,7 @@ struct TORCH_API MobileCode : Code {
std::string function_name,
bool emit_default_input_instructions = true,
bool support_default_args_before_out = true,
+      bool emit_promoted_ops = true,
size_t remaining_bailout_depth = 0);
~MobileCode();
};
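Since the new parameter defaults to `true`, existing call sites keep the v8 behavior; a caller emitting for v7 consumers opts out explicitly. A sketch, assuming the usual `std::shared_ptr<Graph>` first parameter that sits just above this hunk:

```cpp
#include <torch/csrc/jit/runtime/interpreter.h>

// Sketch: emit mobile code for `graph` with promoted ops kept as ordinary
// operator calls, matching what a v7-or-older consumer expects. The flag
// values mirror the backport_v8_to_v7 guard in backport_manager.cpp.
torch::jit::MobileCode makeV7StyleCode(
    const std::shared_ptr<torch::jit::Graph>& graph) {
  return torch::jit::MobileCode(
      graph,
      "forward",
      /*emit_default_input_instructions=*/false,
      /*support_default_args_before_out=*/true,
      /*emit_promoted_ops=*/false);
}
```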
7 changes: 5 additions & 2 deletions torch/csrc/jit/runtime/interpreter/code_impl.h
@@ -866,10 +866,12 @@ struct MobileCodeImpl : CodeImpl {
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
+      bool emit_promoted_ops,
size_t remaining_bailout_depth)
: CodeImpl(graph, function_name, remaining_bailout_depth, false),
emit_default_input_instructions_(emit_default_input_instructions),
-        support_default_args_before_out_(support_default_args_before_out) {
+        support_default_args_before_out_(support_default_args_before_out),
+        emit_promoted_ops_(emit_promoted_ops) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
run();
}
@@ -962,7 +964,6 @@ struct MobileCodeImpl : CodeImpl {
int64_t X = 0,
uint64_t N = 0,
bool emit_inputs = true) override {
-    bool emit_promoted_ops_ = false;
if (emit_promoted_ops_) {
CodeImpl::emitOperatorOrInstruction(node, op, X, N, emit_inputs);
} else {
@@ -974,6 +975,8 @@
bool emit_default_input_instructions_;
// To support forward compatibility for bytecode version bump from v6 to v7
bool support_default_args_before_out_;
+  // To support forward compatibility for bytecode version bump from v7 to v8
+  bool emit_promoted_ops_;
};

} // namespace interpreter
15 changes: 13 additions & 2 deletions torch/csrc/jit/serialization/export.h
@@ -201,6 +201,9 @@ struct TORCH_API BytecodeEmitMode {

static bool is_default_args_before_out_args_enabled();
static void set_default_args_before_out_args_enabled(bool enabled);

+  static bool is_emit_promoted_ops_enabled();
+  static void set_default_emit_promoted_ops_enabled(bool enabled);
};

// RAII guard to switch the way JIT emits the bytecode for inputs.
@@ -216,24 +219,32 @@ struct TORCH_API BytecodeEmitMode {
struct TORCH_API BytecodeEmitModeGuard {
BytecodeEmitModeGuard(
bool enable_default_value_for_unspecified_arg,
-      bool enable_default_args_before_out_args)
+      bool enable_default_args_before_out_args,
+      bool enable_emit_promoted_ops)
: prev_default_value_for_unspecified_arg_mode(
BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
prev_default_args_before_out_args(
-            BytecodeEmitMode::is_default_args_before_out_args_enabled()) {
+            BytecodeEmitMode::is_default_args_before_out_args_enabled()),
+        prev_default_emit_promoted_ops(
+            BytecodeEmitMode::is_emit_promoted_ops_enabled()) {
BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
enable_default_value_for_unspecified_arg);
BytecodeEmitMode::set_default_args_before_out_args_enabled(
enable_default_args_before_out_args);
+    BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
+        enable_emit_promoted_ops);
}
~BytecodeEmitModeGuard() {
BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
prev_default_value_for_unspecified_arg_mode);
BytecodeEmitMode::set_default_args_before_out_args_enabled(
prev_default_args_before_out_args);
+    BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
+        prev_default_emit_promoted_ops);
}
bool prev_default_value_for_unspecified_arg_mode;
bool prev_default_args_before_out_args;
+  bool prev_default_emit_promoted_ops;
};

TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
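Putting the guard to work: a sketch of saving a single module with v7-style emission from a build whose defaults now target v8 (flag values as in `backport_v8_to_v7`; `_save_for_mobile` as used throughout this diff):

```cpp
#include <ostream>

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/serialization/export.h>

// Sketch: the RAII guard flips the thread-local emit flags for the scope
// of one save and restores the previous values on destruction.
void saveWithV7Emission(torch::jit::Module& module, std::ostream& out) {
  torch::jit::BytecodeEmitModeGuard guard(
      /*enable_default_value_for_unspecified_arg=*/false,
      /*enable_default_args_before_out_args=*/true,
      /*enable_emit_promoted_ops=*/false);
  module._save_for_mobile(out);
}
```

Note that the guard only controls instruction emission; the backport functions additionally rewrite the stored version number via `update_bytecode_version`.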
3 changes: 2 additions & 1 deletion torch/csrc/jit/serialization/export_bytecode.cpp
@@ -142,7 +142,8 @@ mobile::Code compileGraphToMobileCode(
graph,
name,
compilation_options.enable_default_value_for_unspecified_arg,
-      compilation_options.enable_default_args_before_out_args);
+      compilation_options.enable_default_args_before_out_args,
+      compilation_options.enable_emit_promoted_ops);

mobile::Code mobile_code;

1 change: 1 addition & 0 deletions torch/csrc/jit/serialization/export_bytecode.h
@@ -20,6 +20,7 @@ struct TORCH_API CompilationOptions {
bool incl_interface_call = false;
bool enable_default_value_for_unspecified_arg = false;
bool enable_default_args_before_out_args = true;
+  bool enable_emit_promoted_ops = true;
int model_version = caffe2::serialize::kProducedBytecodeVersion;
};

11 changes: 11 additions & 0 deletions torch/csrc/jit/serialization/export_module.cpp
@@ -44,6 +44,8 @@ CompilationOptions getOptionsFromGlobal() {
BytecodeEmitMode::is_default_args_before_out_args_enabled();
compilation_options.enable_default_value_for_unspecified_arg =
BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
+  compilation_options.enable_emit_promoted_ops =
+      BytecodeEmitMode::is_emit_promoted_ops_enabled();
compilation_options.incl_interface_call = getMobileInterfaceCallExport();
compilation_options.model_version =
caffe2::serialize::kProducedBytecodeVersion;
@@ -851,5 +853,14 @@ void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) {
emitDefautlArgsWithOutArgs = enabled;
}

+thread_local bool emitDefaultEmitPromotedOps =
+    caffe2::serialize::kProducedBytecodeVersion <= 7 ? false : true;
+bool BytecodeEmitMode::is_emit_promoted_ops_enabled() {
+  return emitDefaultEmitPromotedOps;
+}
+void BytecodeEmitMode::set_default_emit_promoted_ops_enabled(bool enabled) {
+  emitDefaultEmitPromotedOps = enabled;
+}

} // namespace jit
} // namespace torch
