30 changes: 30 additions & 0 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -5,6 +5,7 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
@@ -793,6 +794,35 @@ TEST(LiteInterpreterTest, ExtraFiles) {
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
}

TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
torch::jit::Module m("m");
m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
m.register_parameter("bias", torch::ones({20}), false);
m.define(R"(
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
return (x1, x2, x3)
)");
m.eval();

std::stringstream ss;
m._save_for_mobile(ss);

torch::jit::mobile::Module ptl_model = torch::jit::_load_for_mobile(ss);
std::set<std::string> operator_names =
torch::jit::mobile::_export_operator_list(ptl_model);
std::set<std::string> expected_operator_names = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
};
EXPECT_EQ(operator_names, expected_operator_names)
<< "Expected the root operator lists to be the same";
}

namespace {
static auto reg =
torch::class_<TorchBindLiteInterpreterTestStruct>(
4 changes: 4 additions & 0 deletions torch/csrc/jit/mobile/function.cpp
@@ -130,6 +130,10 @@ c10::IValue Function::operator()(Stack& stack) const {
return stack.front();
}

const std::shared_ptr<Code> Function::get_code() const {
return code_;
}

} // namespace mobile
} // namespace jit
} // namespace torch
1 change: 1 addition & 0 deletions torch/csrc/jit/mobile/function.h
@@ -32,6 +32,7 @@ class Function {
void set_register_size(size_t size);

std::string get_module_debug_info(size_t pc) const;
const std::shared_ptr<Code> get_code() const;

void setSchema(c10::FunctionSchema schema);
const at::optional<c10::FunctionSchema>& getSchema() const;
19 changes: 17 additions & 2 deletions torch/csrc/jit/mobile/import.cpp
@@ -121,7 +121,8 @@ class BytecodeDeserializer final {
private:
TypePtr resolveTypeName(const c10::QualifiedName& qn);
void parseMethods(
const std::vector<IValue>& vals,
const std::vector<IValue>&
vals, // vals is a list of all methods in the model.
const c10::optional<std::vector<IValue>>& debug_info_vals,
mobile::CompilationUnit& mcu);
c10::IValue readArchive(
@@ -159,7 +160,8 @@ TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
}

void BytecodeDeserializer::parseMethods(
const std::vector<IValue>& vals,
const std::vector<IValue>&
vals, // vals is a list of all methods in the model.
const c10::optional<std::vector<IValue>>& debug_info_vals,
mobile::CompilationUnit& mcu) {
TORCH_CHECK(vals.size() > 0, "Bytecode has no elements. ");
@@ -190,6 +192,7 @@ void BytecodeDeserializer::parseMethods(
"The numbers of bytecode values and debug info values do not match.");
}

// Process all methods in this mobile module.
for (size_t i = method_i_start; i < vals.size(); ++i) {
const auto& element = vals[i];
const auto& m_tuple = element.toTuple()->elements();
@@ -263,6 +266,8 @@ }
}

std::unordered_set<std::string> unsupported_op_names;
// ops_list is the list of operator names that were read in from
// bytecode.pkl for the method that is currently being processed.
for (const auto& op : ops_list) {
auto op_item = op.toTuple()->elements();
TORCH_CHECK(
@@ -358,6 +363,16 @@ mobile::Module BytecodeDeserializer::deserialize(
}
}
auto mcu = std::make_shared<mobile::CompilationUnit>();

// bvals can have 2 possible formats:
//
// 1. Old format: bvals is an array (Tuple) of N elements, where each element
// is itself a Tuple(method_name, method_table).
//
// 2. New format: bvals is an array (Tuple) of 1+N elements. The first element
// is a Tuple(int, table), where the integer is the bytecode version number.
// The remaining N elements are the same as in the old format.
//
auto bvals = readArchive("bytecode", mcu).toTuple()->elements();

c10::optional<std::vector<IValue>> debug_info_bvals;
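As an aside, here is a minimal sketch (not part of this diff) of how a reader of the "bytecode" archive could distinguish the two layouts described in the comment above. It assumes, per that comment, that the new format's leading element is a Tuple whose first entry is the integer version number; the helper and struct names are hypothetical.

#include <ATen/core/ivalue.h>

#include <cstdint>
#include <vector>

// Illustrative only: not part of the PyTorch API.
struct BytecodeLayout {
  int64_t version = 0;      // 0 means no version element was found (old format)
  size_t method_start = 0;  // index of the first (method_name, method_table) tuple
};

inline BytecodeLayout detect_bytecode_layout(
    const std::vector<c10::IValue>& bvals) {
  BytecodeLayout layout;
  if (!bvals.empty() && bvals[0].isTuple()) {
    const auto& first = bvals[0].toTuple()->elements();
    // New format: the leading tuple carries the integer bytecode version.
    if (!first.empty() && first[0].isInt()) {
      layout.version = first[0].toInt();
      layout.method_start = 1;
    }
  }
  return layout;
}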
2 changes: 2 additions & 0 deletions torch/csrc/jit/mobile/import.h
@@ -14,6 +14,8 @@ using caffe2::serialize::ReadAdapterInterface;
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
static ExtraFilesMap default_extra_files_mobile;

// The family of methods below converts a serialized Mobile Module
// into a mobile::Module object.
TORCH_API mobile::Module _load_for_mobile(
std::istream& in,
c10::optional<at::Device> device = c10::nullopt,
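For context, a minimal sketch (not part of this diff) of the load path these methods provide, using the std::istream overload declared above together with the ExtraFilesMap mechanism exercised by the ExtraFiles test earlier in this PR; the helper name and file path are hypothetical.

#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>

#include <fstream>
#include <string>

// Illustrative only: loads a module saved with _save_for_mobile() and also
// requests the contents of metadata.json from the archive, if present.
torch::jit::mobile::Module load_with_metadata(
    const std::string& path,
    std::string& metadata_out) {
  std::ifstream in(path, std::ios::binary);
  torch::jit::ExtraFilesMap extra_files{{"metadata.json", ""}};
  torch::jit::mobile::Module m =
      torch::jit::_load_for_mobile(in, c10::nullopt, extra_files);
  metadata_out = extra_files["metadata.json"];
  return m;
}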
2 changes: 2 additions & 0 deletions torch/csrc/jit/mobile/method.h
@@ -1,3 +1,5 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/function.h>

8 changes: 8 additions & 0 deletions torch/csrc/jit/mobile/module.cpp
@@ -116,6 +116,14 @@ bool Module::is_training() const {
return true;
}

const std::vector<Method> Module::get_methods() const {
std::vector<Method> methods;
for (std::unique_ptr<Function>& fn : cu_->methods()) {
methods.emplace_back(this, fn.get());
}
return methods;
}

Method::Method(const Module* owner, Function* function)
: owner_(owner), function_(function) {}

30 changes: 30 additions & 0 deletions torch/csrc/jit/mobile/module.h
@@ -7,6 +7,27 @@ namespace jit {
namespace mobile {
using Stack = std::vector<c10::IValue>;

// A CompilationUnit object is the one that gets executed by the lite
// interpreter.
//
// A CompilationUnit object contains a list of Method objects. These are the
// methods that appear in the original PyTorch model, and they correspond to the
// Python member functions of the model class.
//
// Methods in turn contain a Function, and a back-pointer to the Module that
// owns this Method instance.
//
// A Function contains a Code Object (code_) which is defined in interpreter.h
//
// A Code object contains the following:
//
// std::vector<Instruction> instructions_;
// std::vector<c10::OperatorName> op_names_;
// std::vector<std::function<void(Stack&)>> operators_;
// std::vector<c10::IValue> constants_;
// std::vector<c10::TypePtr> types_;
// size_t register_size_; // Aggregated output size.
//
class CompilationUnit {
public:
void register_function(std::unique_ptr<Function> fn);
@@ -19,6 +40,14 @@ class CompilationUnit {
std::vector<std::unique_ptr<Function>> methods_;
};

// A Torch Mobile Module is a representation of the model (a trained one, in
// the case of inference). A Mobile Module contains:
//
// 1. data (object_)
// 2. metadata (optional) about the model (metadata_ from the metadata.pkl
// file added after training)
// 3. Compilation Unit (cu_)
//
class TORCH_API Module {
public:
Module(
@@ -65,6 +94,7 @@ class TORCH_API Module {
const std::unordered_map<std::string, std::string> metadata() const {
return metadata_;
}
const std::vector<Method> get_methods() const;

c10::IValue attr(const std::string& name, c10::IValue or_else) const {
if (auto r = object_->type()->findAttributeSlot(name)) {
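To make the ownership chain described above concrete, here is a minimal sketch (not part of this diff) that walks Module -> Method -> Function -> Code using the get_methods() and get_code() accessors added in this PR; the helper name is hypothetical.

#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/module.h>

#include <iostream>

// Illustrative only: prints every operator name referenced by every method.
void dump_op_names(const torch::jit::mobile::Module& module) {
  for (torch::jit::mobile::Method method : module.get_methods()) {
    const torch::jit::mobile::Function& fn = method.function();
    const std::shared_ptr<torch::jit::mobile::Code> code = fn.get_code();
    for (const c10::OperatorName& op_name : code->op_names_) {
      std::cout << toString(op_name) << std::endl;
    }
  }
}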
17 changes: 17 additions & 0 deletions torch/csrc/jit/serialization/export.h
@@ -111,5 +111,22 @@ TORCH_API void SetExportModuleMobileInfoConverter(
// Returns a list of names of all operators in the module and its submodules.
TORCH_API std::vector<std::string> export_opnames(const Module& m);

namespace mobile {

class Module;
/**
* Given a torch::jit::mobile::Module, return a set of operator names
* (with overload name) that are used by any method in this mobile
 * Module. This method runs through the bytecode for all methods
* in the specified model (module), and extracts all the root
* operator names. Root operators are operators that are called
* directly by the model (as opposed to non-root operators, which
* may be called transitively by the root operators).
*
*/
TORCH_API std::set<std::string> _export_operator_list(
torch::jit::mobile::Module& module);

} // namespace mobile
} // namespace jit
} // namespace torch
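For reference, a minimal end-to-end usage sketch (not part of this diff) mirroring the new OpNameExportFetchRootOperators test above; the model path is a placeholder.

#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>

#include <fstream>
#include <iostream>
#include <set>
#include <string>

int main() {
  // "model.ptl" stands in for a module saved with Module::_save_for_mobile().
  std::ifstream in("model.ptl", std::ios::binary);
  torch::jit::mobile::Module module = torch::jit::_load_for_mobile(in);

  // Root operators: the operators called directly by the model's methods.
  std::set<std::string> root_ops =
      torch::jit::mobile::_export_operator_list(module);
  for (const std::string& op : root_ops) {
    std::cout << op << std::endl;
  }
  return 0;
}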
25 changes: 25 additions & 0 deletions torch/csrc/jit/serialization/export_module.cpp
@@ -4,6 +4,10 @@
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
@@ -631,5 +635,26 @@ std::vector<std::string> export_opnames(const script::Module& m) {
return std::vector<std::string>(names.begin(), names.end());
}

namespace mobile {

std::set<std::string> _export_operator_list(
torch::jit::mobile::Module& module) {
std::set<std::string> operator_list;
for (Method func : module.get_methods()) {
const Function& function = func.function();
const std::shared_ptr<Code> cptr = function.get_code();
// op_names below isn't a list of unique operator names. In fact
// it can contain the same operator name many times, so we need
// to de-dup the list by adding all the operator names into
// an std::set<std::string>.
std::vector<c10::OperatorName> const& op_names = cptr->op_names_;
for (auto& op_name : op_names) {
operator_list.insert(toString(op_name));
}
}
return operator_list;
}

} // namespace mobile
} // namespace jit
} // namespace torch