Skip to content

Commit

Permalink
Generalize constant_table from tensor only to ivalue
Browse files Browse the repository at this point in the history
Currently, every constant except tensors must be inlined during serialization;
tensors are stored in the constant table. This patch generalizes that capability
to any IValue. This is particularly useful for non-ASCII string literals, which
cannot be inlined.
  • Loading branch information
bzinodev committed May 18, 2020
1 parent 8292742 commit 29540c3
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 83 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/test/type_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ static TypePtr importType(
std::shared_ptr<CompilationUnit> cu,
const std::string& qual_name,
const std::string& src) {
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
auto source = std::make_shared<torch::jit::Source>(src);
torch::jit::SourceImporter si(
cu,
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/jit/test_class_import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& tensor_table) {
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
Expand All @@ -48,7 +48,7 @@ static void import_libs(
void testClassImport() {
auto cu1 = std::make_shared<CompilationUnit>();
auto cu2 = std::make_shared<CompilationUnit>();
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
Expand Down Expand Up @@ -83,7 +83,7 @@ void testClassImport() {
void testScriptObject() {
Module m1("m1");
Module m2("m2");
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
import_libs(
m1._ivalue()->compilation_unit(),
"__torch__.FooTest",
Expand Down Expand Up @@ -144,7 +144,7 @@ class FooBar1234(Module):

void testSaveLoadTorchbind() {
auto cu1 = std::make_shared<CompilationUnit>();
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/jit/test_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& tensor_table) {
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
Expand All @@ -49,7 +49,7 @@ void testModuleInterfaceSerialization() {
Module parentMod("parentMod", cu);
Module subMod("subMod", cu);

std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
import_libs(
cu,
"__torch__.OneForward",
Expand Down
52 changes: 31 additions & 21 deletions test/jit/test_freezing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,58 @@ def test_freeze_module(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.c2 = "hi\xA1" # not folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.f2 = [(1, "Over \u0e55\u0e57 57")]
self.g = ([1, 2], 3.2, "4.4", torch.tensor([5.5], requires_grad=True)) # folded
self.h = {"layer" : [torch.tensor([7.7], requires_grad=True)]}
self.h2 = {"layer\xB1" : [torch.tensor([8.8], requires_grad=True)]}
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
self.ts = [torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([3.0, 4.0], requires_grad=True)] # folded
self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]

def forward(self, x):
return str(self.a) + str(self.b) + self.c + str(self.d) + \
str(self.e) + str(self.f) + str(self.g) + str(self.h['layer']) + str(self.t) + str(self.ts) + str(self.tt)
return str(self.a) + str(self.b) + self.c + self.c2 + str(self.d) + \
str(self.e) + str(self.f) + str(self.f2) + str(self.g) + \
str(self.h) + str(self.h2) + str(self.t) + str(self.ts) + str(self.tt)


m = torch.jit.script(M())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
m._c = torch._C._freeze_module(m._c)

buffer = io.BytesIO()
torch.jit.save(m._c, buffer)
buffer.seek(0)
m2 = torch.jit.load(buffer)
# Check if frozen module looks as below:
# module m {
# attributes {
# tt = ...
# }
# ...
# }
self.assertFalse(m._c.hasattr('a'))
self.assertFalse(m._c.hasattr('b'))
self.assertFalse(m._c.hasattr('c'))
self.assertFalse(m._c.hasattr('d'))
self.assertFalse(m._c.hasattr('e'))
self.assertFalse(m._c.hasattr('f'))
self.assertFalse(m._c.hasattr('g'))
self.assertFalse(m._c.hasattr('h'))
self.assertFalse(m._c.hasattr('t'))
self.assertFalse(m._c.hasattr('ts'))
self.assertFalse(m._c.hasattr('tt'))
output_f = m.forward(input)
self.assertFalse(m2._c.hasattr('a'))
self.assertFalse(m2._c.hasattr('b'))
self.assertFalse(m2._c.hasattr('c'))
self.assertFalse(m2._c.hasattr('c2'))
self.assertFalse(m2._c.hasattr('d'))
self.assertFalse(m2._c.hasattr('e'))
self.assertFalse(m2._c.hasattr('f'))
self.assertFalse(m2._c.hasattr('f2'))
self.assertFalse(m2._c.hasattr('g'))
self.assertFalse(m2._c.hasattr('h'))
self.assertFalse(m2._c.hasattr('h2'))
self.assertFalse(m2._c.hasattr('t'))
self.assertFalse(m2._c.hasattr('ts'))
self.assertFalse(m2._c.hasattr('tt'))
output_f = m2.forward(input)
self.assertEqual(output_s, output_f)

def test_freeze_module_with_submodule(self):
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/jit_log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ bool is_enabled(const char* cfname, JitLoggingLevels level) {
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printFunction(func);
return pp.str();
}
Expand Down
32 changes: 16 additions & 16 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,23 +988,23 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](Module& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
return pp.str();
})
.def_property_readonly(
"code_with_constants",
[](Module& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
std::map<std::string, at::Tensor> consts;
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& tensor : tensors) {
consts["c" + std::to_string(i)] = tensor;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
Expand Down Expand Up @@ -1108,10 +1108,10 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](const StrongFunctionPtr& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;

PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printFunction(*self.function_);
return pp.str();
})
Expand Down Expand Up @@ -1153,21 +1153,21 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](Method& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
return pp.str();
})
.def_property_readonly("code_with_constants", [](Method& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
std::map<std::string, at::Tensor> consts;
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& tensor : tensors) {
consts["c" + std::to_string(i)] = tensor;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/serialization/export_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class ScriptModuleSerializer {
}

caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
std::vector<at::IValue> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
TypeNameUniquer type_name_uniquer_;
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/serialization/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class ScriptModuleDeserializer final {
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
std::vector<at::Tensor> constants_table_;
std::vector<at::IValue> constants_table_;
SourceImporter source_importer_;
std::string export_prefix_ = "code/";
};
Expand Down Expand Up @@ -264,7 +264,7 @@ Module ScriptModuleDeserializer::deserialize(
}
auto tuple = readArchive("constants").toTuple();
for (auto constant : tuple->elements()) {
constants_table_.push_back(constant.toTensor());
constants_table_.push_back(constant.toIValue());
}
auto m = Module(readArchive("data").toObject());
rewriteQuantizedConvForBC(m);
Expand Down
19 changes: 13 additions & 6 deletions torch/csrc/jit/serialization/import_legacy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ class ScriptModuleDeserializer final {
device_(device),
source_importer_(
compilation_unit_,
&constants_table_,
&constant_table_,
[this](const std::string& qualifier) {
return findSourceInArchiveFromQualifier(
*reader_, export_prefix_, qualifier);
},
reader_->version()) {}
reader_->version()) {
for (auto& constant : constant_table_) {
TORCH_INTERNAL_ASSERT(constant.isTensor(), " expected a tensor");
tensor_table_.emplace_back(std::move(constant).toTensor());
}
}

Module LEGACY_deserialize();

Expand All @@ -73,7 +78,9 @@ class ScriptModuleDeserializer final {
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
std::vector<at::Tensor> constants_table_;
// Legacy only tensor can be a constant.
std::vector<at::IValue> constant_table_;
std::vector<at::Tensor> tensor_table_;
SourceImporter source_importer_;
std::string export_prefix_ = "code/";
};
Expand Down Expand Up @@ -140,15 +147,15 @@ IValue ScriptModuleDeserializer::LEGACY_loadPickleArchive(
auto cls = source_importer_.loadType(qn)->expect<ClassType>();
return c10::StrongTypePtr(compilation_unit_, std::move(cls));
},
&constants_table_);
&tensor_table_);
return ivalue;
}

void ScriptModuleDeserializer::LEGACY_loadTensorTable(
torch::ModelDef* model_def) {
std::unordered_map<std::string, at::Storage> storageMap;
for (const torch::TensorDef& tensor : model_def->tensors()) {
constants_table_.emplace_back(LEGACY_loadTensor(tensor, storageMap));
tensor_table_.emplace_back(LEGACY_loadTensor(tensor, storageMap));
}
}

Expand Down Expand Up @@ -284,7 +291,7 @@ Module ScriptModuleDeserializer::LEGACY_convertModule(
}
for (int i = 0; i < module_def.parameters_size(); ++i) {
const torch::ParameterDef& param_def = module_def.parameters(i);
at::Tensor tensor = constants_table_.at(param_def.tensor_id());
at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
if (param_def.is_buffer()) {
module.register_buffer(param_def.name(), tensor);
} else {
Expand Down
12 changes: 6 additions & 6 deletions torch/csrc/jit/serialization/import_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct TORCH_API ClassNamespaceValue : public SugaredValue {
// in the 'constants' vector. This table is will be stored in a container format
// and given to the import_method when restoring the code.
struct ConstantTableValue : public SugaredValue {
ConstantTableValue(const std::vector<at::Tensor>* constants)
ConstantTableValue(const std::vector<at::IValue>* constants)
: constants_(constants) {}
std::string kind() const override {
return "CONSTANTS";
Expand Down Expand Up @@ -86,14 +86,14 @@ struct ConstantTableValue : public SugaredValue {
}

private:
const std::vector<at::Tensor>* constants_;
const std::vector<at::IValue>* constants_;
};

struct SourceImporterImpl : public Resolver,
std::enable_shared_from_this<SourceImporterImpl> {
SourceImporterImpl(
const std::shared_ptr<CompilationUnit> cu,
const std::vector<at::Tensor>* tensor_table,
const std::vector<at::IValue>* constant_table,
SourceLoader source_loader,
size_t version)
: cu_(cu), source_loader_(std::move(source_loader)) {
Expand All @@ -102,7 +102,7 @@ struct SourceImporterImpl : public Resolver,
{"ops", std::make_shared<OpsValue>(version)},
// Constants present in the model. Used to resolve "CONSTANTS.n" to the
// actual value
{"CONSTANTS", std::make_shared<ConstantTableValue>(tensor_table)},
{"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
{"fork", SpecialFormValue::create(prim::fork)},
{"annotate", SpecialFormValue::create(prim::annotate)},
{"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)},
Expand Down Expand Up @@ -556,12 +556,12 @@ std::shared_ptr<SugaredValue> ClassNamespaceValue::attr(
SourceImporter::SourceImporter(
// The compilation unit that will own the imported source
std::shared_ptr<CompilationUnit> cu,
const std::vector<at::Tensor>* tensor_table,
const std::vector<IValue>* constant_table,
SourceLoader loader,
size_t version)
: pImpl(std::make_shared<SourceImporterImpl>(
std::move(cu),
tensor_table,
constant_table,
std::move(loader),
version)) {}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/serialization/import_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct TORCH_API SourceImporter {
SourceImporter(
// The compilation unit that will own the imported source
std::shared_ptr<CompilationUnit> cu,
const std::vector<at::Tensor>* tensor_table,
const std::vector<at::IValue>* constant_table,
SourceLoader loader,
size_t version);

Expand Down

0 comments on commit 29540c3

Please sign in to comment.