Skip to content

Commit

Permalink
Generalize constant_table from tensor only to ivalue
Browse files Browse the repository at this point in the history
Currently only constant except tensor must be inlined during serialization.
Tensor are stored in the contant table. This patch generalizes this capability
to any IValue. This is particularly useful for non ASCII string literal that
cannot be inlined.
  • Loading branch information
bzinodev committed May 29, 2020
1 parent 1d0ec50 commit 35f059c
Show file tree
Hide file tree
Showing 17 changed files with 190 additions and 80 deletions.
40 changes: 40 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,46 @@ TypePtr IValue::type() const {
TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
}

void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
if (visitor(*this)) {
// Short cut.
return;
}
switch (this->tag) {
case Tag::Tuple:
case Tag::GenericList: {
c10::ArrayRef<IValue> elems;
if (isTuple()) {
elems = this->toTuple()->elements();
} else {
elems = this->toListRef();
}
for (auto& elem : elems) {
elem.visit(visitor);
}
break;
}
case Tag::GenericDict:
for (const auto& pair : this->toGenericDict()) {
pair.value().visit(visitor);
pair.key().visit(visitor);
}
break;
case Tag::Object: {
auto obj_type = type()->expect<ClassType>();
auto obj_value = toObject();
auto attributes = obj_type->getAttributes();
for (const auto& attr: attributes) {
auto attribute = obj_value->getAttr(attr.getName());
attribute.visit(visitor);
}
break;
}
default:
break;
}
}

void IValue::getSubValues(HashAliasedIValues& subValues) const {
switch (this->tag) {
case Tag::Tensor:
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,10 @@ struct CAFFE2_API IValue final {
// Inserts all subvalues of this in subValues.
void getSubValues(HashAliasedIValues& subValues) const;

// Apply visitor to every subvalue.
// TODO: There are several places that recurse over IValue. This is fragile.
// This visitor should be used to recurse over ivalues.
void visit(const std::function<bool (const IValue &)>& visitor) const;
IValue deepcopy() const;
IValue deepcopy(
HashAliasedIValueMap& memo) const;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/test/type_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ static TypePtr importType(
std::shared_ptr<CompilationUnit> cu,
const std::string& qual_name,
const std::string& src) {
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
auto source = std::make_shared<torch::jit::Source>(src);
torch::jit::SourceImporter si(
cu,
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/jit/test_class_import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& tensor_table) {
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
Expand All @@ -48,7 +48,7 @@ static void import_libs(
void testClassImport() {
auto cu1 = std::make_shared<CompilationUnit>();
auto cu2 = std::make_shared<CompilationUnit>();
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
Expand Down Expand Up @@ -83,7 +83,7 @@ void testClassImport() {
void testScriptObject() {
Module m1("m1");
Module m2("m2");
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
import_libs(
m1._ivalue()->compilation_unit(),
"__torch__.FooTest",
Expand Down Expand Up @@ -144,7 +144,7 @@ class FooBar1234(Module):

void testSaveLoadTorchbind() {
auto cu1 = std::make_shared<CompilationUnit>();
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/jit/test_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& tensor_table) {
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
Expand All @@ -49,7 +49,7 @@ void testModuleInterfaceSerialization() {
Module parentMod("parentMod", cu);
Module subMod("subMod", cu);

std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
import_libs(
cu,
"__torch__.OneForward",
Expand Down
1 change: 0 additions & 1 deletion test/cpp/jit/test_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,6 @@ void checkTracedInputs(const TracedTestInputs& inputs) {
TORCH_CHECK(found_mul);
}


void checkScopeCallbacks() {
bool found_function_scope = false;
bool found_method_scope = false;
Expand Down
56 changes: 33 additions & 23 deletions test/jit/test_freezing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import unittest
import torch
import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase
Expand All @@ -12,52 +13,63 @@
"instead.")

class TestFreezing(JitTestCase):
@unittest.skip("temporarily disable the test for fwd compatibility")
def test_freeze_module(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.c2 = "hi\xA1" # not folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.f2 = [(1, "Over \u0e55\u0e57 57")]
self.g = ([1, 2], 3.2, "4.4", torch.tensor([5.5], requires_grad=True)) # folded
self.h = {"layer" : [torch.tensor([7.7], requires_grad=True)]}
self.h2 = {"layer\xB1" : [torch.tensor([8.8], requires_grad=True)]}
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
self.ts = [torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([3.0, 4.0], requires_grad=True)] # folded
self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]

def forward(self, x):
return str(self.a) + str(self.b) + self.c + str(self.d) + \
str(self.e) + str(self.f) + str(self.g) + str(self.h['layer']) + str(self.t) + str(self.ts) + str(self.tt)
return str(self.a) + str(self.b) + self.c + self.c2 + str(self.d) + \
str(self.e) + str(self.f) + str(self.f2) + str(self.g) + \
str(self.h) + str(self.h2) + str(self.t) + str(self.ts) + str(self.tt)


m = torch.jit.script(M())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
m._c = torch._C._freeze_module(m._c)

buffer = io.BytesIO()
torch.jit.save(m._c, buffer)
buffer.seek(0)
m2 = torch.jit.load(buffer)
# Check if frozen module looks as below:
# module m {
# attributes {
# tt = ...
# }
# ...
# }
self.assertFalse(m._c.hasattr('a'))
self.assertFalse(m._c.hasattr('b'))
self.assertFalse(m._c.hasattr('c'))
self.assertFalse(m._c.hasattr('d'))
self.assertFalse(m._c.hasattr('e'))
self.assertFalse(m._c.hasattr('f'))
self.assertFalse(m._c.hasattr('g'))
self.assertFalse(m._c.hasattr('h'))
self.assertFalse(m._c.hasattr('t'))
self.assertFalse(m._c.hasattr('ts'))
self.assertFalse(m._c.hasattr('tt'))
output_f = m.forward(input)
self.assertFalse(m2._c.hasattr('a'))
self.assertFalse(m2._c.hasattr('b'))
self.assertFalse(m2._c.hasattr('c'))
self.assertFalse(m2._c.hasattr('c2'))
self.assertFalse(m2._c.hasattr('d'))
self.assertFalse(m2._c.hasattr('e'))
self.assertFalse(m2._c.hasattr('f'))
self.assertFalse(m2._c.hasattr('f2'))
self.assertFalse(m2._c.hasattr('g'))
self.assertFalse(m2._c.hasattr('h'))
self.assertFalse(m2._c.hasattr('h2'))
self.assertFalse(m2._c.hasattr('t'))
self.assertFalse(m2._c.hasattr('ts'))
self.assertFalse(m2._c.hasattr('tt'))
output_f = m2.forward(input)
self.assertEqual(output_s, output_f)

def test_freeze_module_with_submodule(self):
Expand Down Expand Up @@ -790,8 +802,6 @@ def forward(self, x):
FileCheck().check_not('GetAttr[name=') \
.run(m._c._get_method('forward').graph)



def test_freeze_module_detach_gradient(self):
mod = nn.Conv2d(8, 3, 4, 2, 1)
self.assertTrue(mod.weight.requires_grad)
Expand Down
19 changes: 19 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3169,6 +3169,25 @@ def forward(self, x, y):
foo_loaded = torch.jit.load(buffer)
self.assertExpected(foo_loaded.forward.code)

@unittest.skip("temporarily disable the test for fwd compatibility")
def test_non_ascii_string(self):
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
self.a = "Over \u0e55\u0e57 57"

@torch.jit.script_method
def forward(self, x, y):
return self.a + "hi\xA1"

foo = Foo()
buffer = io.BytesIO()
torch.jit.save(foo, buffer)

buffer.seek(0)
foo_loaded = torch.jit.load(buffer)
self.assertExpected(foo_loaded.forward.code)

def test_function_default_values(self):
outer_var = torch.tensor(20)
outer_var2 = torch.tensor(30)
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/jit_log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ bool is_enabled(const char* cfname, JitLoggingLevels level) {
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printFunction(func);
return pp.str();
}
Expand Down
32 changes: 16 additions & 16 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,23 +959,23 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](Module& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
return pp.str();
})
.def_property_readonly(
"code_with_constants",
[](Module& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
std::map<std::string, at::Tensor> consts;
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& tensor : tensors) {
consts["c" + std::to_string(i)] = tensor;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
Expand Down Expand Up @@ -1082,10 +1082,10 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](const StrongFunctionPtr& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;

PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printFunction(*self.function_);
return pp.str();
})
Expand Down Expand Up @@ -1127,21 +1127,21 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](Method& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
return pp.str();
})
.def_property_readonly("code_with_constants", [](Method& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
std::map<std::string, at::Tensor> consts;
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& tensor : tensors) {
consts["c" + std::to_string(i)] = tensor;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/serialization/export_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class ScriptModuleSerializer {
}

caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
std::vector<at::IValue> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
TypeNameUniquer type_name_uniquer_;
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/serialization/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class ScriptModuleDeserializer final {
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
std::vector<at::Tensor> constants_table_;
std::vector<at::IValue> constants_table_;
SourceImporter source_importer_;
std::string export_prefix_ = "code/";
};
Expand Down Expand Up @@ -264,7 +264,7 @@ Module ScriptModuleDeserializer::deserialize(
}
auto tuple = readArchive("constants").toTuple();
for (auto constant : tuple->elements()) {
constants_table_.push_back(constant.toTensor());
constants_table_.push_back(constant.toIValue());
}
auto m = Module(readArchive("data").toObject());
rewriteQuantizedConvForBC(m);
Expand Down

0 comments on commit 35f059c

Please sign in to comment.