Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions binaries/lite_interpreter_model_load.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "ATen/ATen.h"
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/import.h>
#include "torch/script.h"

C10_DEFINE_string(model, "", "The given bytecode model to check if it is supported by lite_interpreter.");

int main(int argc, char** argv) {
c10::SetUsageMessage(
"Check if exported bytecode model is runnable by lite_interpreter.\n"
"Example usage:\n"
"./lite_interpreter_model_load"
" --model=<model_file>");

if (!c10::ParseCommandLineFlags(&argc, &argv)) {
std::cerr << "Failed to parse command line flags!" << std::endl;
return 1;
}

if (FLAGS_model.empty()) {
std::cerr << FLAGS_model << ":Model file is not provided\n";
return -1;
}

torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(FLAGS_model);
return 0;
}
7 changes: 5 additions & 2 deletions torch/csrc/jit/mobile/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void Function::append_instruction(OpCode op, int X, int N) {
code_->instructions_.emplace_back(op, X, N);
}

void Function::append_operator(const std::string& name,
bool Function::append_operator(const std::string& name,
const std::string& overload_name) {
// Keep the original opname in code_
code_->op_names_.emplace_back(name, overload_name);
Expand All @@ -29,13 +29,16 @@ void Function::append_operator(const std::string& name,
opname.name = "_" + opname.name;
}
auto op = c10::Dispatcher::singleton().findSchema(opname);
TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found.");
if (!op.has_value()) {
return false;
}
// TODO: operator.h now does not depend on Node* so we can also look up operators from
// that registry for use in mobile as a way to share implementations.
auto fn = [op](Stack& stack) {
c10::Dispatcher::singleton().callBoxed(*op, &stack);
};
code_->operators_.emplace_back(fn);
return true;
}

void Function::append_constant(const c10::IValue& constant) {
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/mobile/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Function{
const std::string& name() const;
const c10::QualifiedName& qualname() const;
void append_instruction(OpCode op, int X, int N);
void append_operator(const std::string& name,
bool append_operator(const std::string& name,
const std::string& overload_name);
void append_constant(const c10::IValue& constant);
void append_type(const c10::TypePtr& type);
Expand Down
19 changes: 18 additions & 1 deletion torch/csrc/jit/mobile/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ IValue expect_field(IValue tup, const std::string& expected_name, size_t entry){
return row->elements().at(1);
}

void print_unsupported_ops_and_throw(const std::unordered_set<std::string>& unsupported_ops) {
std::string error_message("{");
for (const auto& op_name : unsupported_ops) {
error_message += op_name + ", ";
}
error_message += "}";
TORCH_CHECK(false, "Following ops cannot be found:", error_message);
}

void parseMethods(const std::vector<IValue>& vals, mobile::CompilationUnit& mcu) {
for (const auto& element : vals) {
const auto& m_tuple = element.toTuple()->elements();
Expand All @@ -72,14 +81,22 @@ void parseMethods(const std::vector<IValue>& vals, mobile::CompilationUnit& mcu)
function->append_instruction(op_code, X, N);
}

std::unordered_set<std::string> unsupported_op_names;
for (const auto& op : ops_list) {
auto op_item = op.toTuple()->elements();
TORCH_CHECK(op_item.size() == 2,
"There should be two parts in an operator name.");
function->append_operator(op_item[0].toString()->string(),
auto op_found = function->append_operator(op_item[0].toString()->string(),
op_item[1].toString()->string());
if (!op_found) {
unsupported_op_names.emplace(op_item[0].toString()->string() + "." + op_item[1].toString()->string());
}
}

if (!unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
};

for (const auto& constant : consts_list) {
function->append_constant(constant);
}
Expand Down